├── .gitignore
├── Makefile
├── READMATH.md
├── README.md
├── ckpts
│   └── .gitkeep
├── data
│   └── .gitkeep
├── scripts
│   ├── __init__.py
│   ├── analyze.py
│   ├── analyze_verify.py
│   ├── check_noise_propagation.py
│   ├── plot_losses.py
│   └── plot_preds_distn.py
├── src
│   ├── __init__.py
│   ├── attacks.py
│   ├── datasets.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── cifar10_selftrained.py
│   │   ├── classic_resnet.py
│   │   ├── ds_imagenet.py
│   │   ├── lenet.py
│   │   ├── wide_resnet.py
│   │   └── zipdata.py
│   ├── models.py
│   ├── noises
│   │   ├── __init__.py
│   │   ├── noises.py
│   │   ├── test_noises.py
│   │   └── utils.py
│   ├── smooth.py
│   ├── test.py
│   ├── train.py
│   └── verify.py
├── svgs
│   ├── 1418cce7d60743be1c545cd950367159.svg
│   ├── 1fa8048512f84790ef174f591d0cb851.svg
│   ├── 336fefe2418749fabf50594e52f7b776.svg
│   ├── 44c65658d6cd134b1599c29b31949f77.svg
│   ├── 5dc1880e644c7b3a0e9fa954759762ea.svg
│   ├── 8244067f9118b85361c6645cc9f1c526.svg
│   ├── 839a0dc412c4f8670dd1064e0d6d412f.svg
│   ├── 8d2d1eabb21bb41807292151fe468472.svg
│   ├── DistributionVenn.png
│   ├── b52b48d8661f69776e1b6650998d5067.svg
│   ├── bb9e6385ceb6d4a2d83a5b51a3c870c9.svg
│   ├── bd5b313d1d74ae2fc57ddb870603d84b.svg
│   ├── e1085464f81e12de4a74d54d14eb5dc5.svg
│   ├── e703845884313f30712bfc7262a5e65b.svg
│   ├── ec90b4fe342a37de851db6db2b08d4f4.svg
│   ├── envelopes.png
│   ├── robust-radii.png
│   └── table.png
└── tutorial.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pkl
2 | *.pyc
3 | *.torch
4 | *.npy
5 | *.csv
6 | *.eps
7 | *.swp
8 | *.pdf
9 | stdout.txt
10 | misc/
11 | jobs/
12 | data/*/*
13 | .vscode/
14 | .ipynb_checkpoints/
15 | pretrain/*
16 | Untitled*.ipynb
17 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 | python3 -m readme2tex --branch master --nocdn --readme READMATH.md --output README.md
3 |
4 |
--------------------------------------------------------------------------------
/READMATH.md:
--------------------------------------------------------------------------------
1 | # Randomized Smoothing of All Shapes and Sizes
2 |
3 | Last update: July 2020.
4 |
5 | ---
6 |
7 | Code to accompany our paper:
8 |
9 | **Randomized Smoothing of All Shapes and Sizes**
10 | *Greg Yang\*, Tony Duan\*, J. Edward Hu, Hadi Salman, Ilya Razenshteyn, Jerry Li.*
11 | International Conference on Machine Learning (ICML), 2020 [[Paper]](https://arxiv.org/abs/2002.08118) [[Blog Post]](http://decentdescent.org/rs4a1.html)
12 |
13 | Notably, we outperform existing provably $\ell_1$-robust classifiers on ImageNet and CIFAR-10.
14 |
15 | 
16 |
17 | 
18 |
19 | This library implements the algorithms in our paper for computing robust radii for different smoothing distributions against different adversaries; for example, distributions of the form $e^{-\|x\|_\infty^k}$ against an $\ell_1$ adversary.
20 |
21 | The following summarizes the (distribution, adversary) pairs covered here.
22 |
23 | 
24 |
25 | We can compare the certified robust radius each of these distributions implies at a fixed level of $\hat\rho_\mathrm{lower}$, the lower bound on the probability that the classifier returns the top class under noise. Here all noises are instantiated for CIFAR-10 dimensionality ($d=3072$) and normalized to variance $\sigma^2 \triangleq \mathbb{E}[\|x\|_2^2]=1$. Note that the first two rows below certify for the $\ell_1$ adversary while the last row certifies for the $\ell_2$ adversary and the $\ell_\infty$ adversary. For more details see our `tutorial.ipynb` notebook.
26 |
27 | 
28 |
29 | ## Getting Started
30 |
31 | Clone our repository and install dependencies:
32 |
33 | ```shell
34 | git clone https://github.com/tonyduan/rs4a.git
35 | conda create --name rs4a python=3.6
36 | conda activate rs4a
37 | conda install numpy matplotlib pandas seaborn
38 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
39 | pip install torchnet tqdm statsmodels dfply
40 | ```
41 |
42 | ## Experiments
43 |
44 | To reproduce our SOTA $\ell_1$ results on CIFAR-10, we need to train models over
45 | $$
46 | \sigma \in \{0.15, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5\}.
47 | $$
48 | For each value, run the following:
49 |
50 | ```shell
51 | python3 -m src.train \
52 |     --model=WideResNet \
53 |     --noise=Uniform \
54 |     --sigma={sigma} \
55 |     --experiment-name=cifar_uniform_{sigma}
56 |
57 | python3 -m src.test \
58 |     --model=WideResNet \
59 |     --noise=Uniform \
60 |     --sigma={sigma} \
61 |     --experiment-name=cifar_uniform_{sigma} \
62 |     --sample-size-cert=100000 \
63 |     --sample-size-pred=64 \
64 |     --noise-batch-size=512
65 | ```
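
The runs above, one train/test pair per value of sigma, can be scripted. The following is a minimal sketch that only assembles the command lines from the flags shown above; the helper name `build_commands` and the constant `SIGMAS` are illustrative and are not part of this repository.

```python
# Values of sigma listed above for the CIFAR-10 L1 experiments.
SIGMAS = [0.15, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75,
          2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5]


def build_commands(sigma, noise="Uniform", model="WideResNet"):
    """Return (train_cmd, test_cmd) as argument lists for one value of sigma.

    Hypothetical helper: it only mirrors the flags shown in the shell
    snippet above and does not import anything from the repository.
    """
    name = f"cifar_{noise.lower()}_{sigma}"
    common = [f"--model={model}", f"--noise={noise}",
              f"--sigma={sigma}", f"--experiment-name={name}"]
    train_cmd = ["python3", "-m", "src.train", *common]
    test_cmd = ["python3", "-m", "src.test", *common,
                "--sample-size-cert=100000",
                "--sample-size-pred=64",
                "--noise-batch-size=512"]
    return train_cmd, test_cmd


# Print every command so the list can be inspected before anything is run.
for sigma in SIGMAS:
    for cmd in build_commands(sigma):
        print(" ".join(cmd))
```

To execute the commands instead of printing them, each list can be passed to `subprocess.run(cmd, check=True)` from the repository root.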
66 |
67 | The training script will train the model via data augmentation for the specified noise and level of sigma, and save the model checkpoint to a directory `ckpts/experiment_name`.
68 |
69 | The testing script will load the model checkpoint from the `ckpts/experiment_name` directory, make predictions over the entire test set using the smoothed classifier, and certify the $\ell_1, \ell_2,$ and $\ell_\infty$ robust radii of these predictions. Note that by default we make predictions with $64$ samples, certify with $100,000$ samples, and at a failure probability of $\alpha=0.001$.
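
To make the role of the failure probability $\alpha$ concrete: given the number of noisy samples on which the classifier returns the top class, a one-sided binomial confidence bound yields a lower bound on the top-class probability that holds with probability at least $1-\alpha$. The sketch below uses the Clopper-Pearson bound, which is an assumption made for illustration; the function names are hypothetical, and the repository's own `certify_prob_lb` may compute the bound differently.

```python
import math


def binom_tail(n, k, p):
    """P(X >= k) for X ~ Binomial(n, p), with each term computed in log space."""
    if k <= 0:
        return 1.0
    if p <= 0.0:
        return 0.0
    if p >= 1.0:
        return 1.0
    log_p, log_q = math.log(p), math.log1p(-p)
    total = 0.0
    for i in range(k, n + 1):
        log_term = (math.lgamma(n + 1) - math.lgamma(i + 1)
                    - math.lgamma(n - i + 1)
                    + i * log_p + (n - i) * log_q)
        total += math.exp(log_term)
    return min(total, 1.0)


def prob_lower_bound(k, n, alpha):
    """One-sided Clopper-Pearson lower bound on p after k successes in n trials.

    Returns (up to bisection tolerance) the largest p for which observing
    at least k successes has probability at most alpha.
    """
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    for _ in range(60):  # bisection: the tail probability is increasing in p
        mid = (lo + hi) / 2
        if binom_tail(n, k, mid) > alpha:
            hi = mid
        else:
            lo = mid
    return lo


# Example with a small sample size so the pure-Python sum stays fast.
print(prob_lower_bound(k=990, n=1000, alpha=0.001))
```

This pure-Python version is slow at the default of 100,000 certification samples; in practice a library routine such as `statsmodels.stats.proportion.proportion_confint(k, n, alpha=2 * alpha, method="beta")` gives the same one-sided bound as the lower end of its interval.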
70 |
71 | To draw a comparison to the benchmark noises, re-run the above replacing `Uniform` with `Gaussian` and `Laplace`. Then to plot the figures and print the table of results (for $\ell_1$ adversary), run our analysis script:
72 |
73 | ```shell
74 | python3 -m scripts.analyze --dir=ckpts --show --adv=1
75 | ```
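
The certified accuracy that the analysis script reports at a radius $\epsilon$ is the fraction of test points that are both classified correctly and certified at a radius of at least $\epsilon$. Below is a minimal NumPy sketch of that computation on made-up arrays; the function name `certified_accuracy` and the toy data are illustrative and are not taken from the repository.

```python
import numpy as np


def certified_accuracy(radii, preds, labels, eps):
    """Fraction of points that are correct AND certified at radius >= eps."""
    radii, preds, labels = map(np.asarray, (radii, preds, labels))
    return float(((radii >= eps) & (preds == labels)).mean())


# Toy data: four test points; the last one is misclassified.
radii = np.array([0.0, 0.4, 1.2, 2.0])
preds = np.array([3, 1, 7, 2])
labels = np.array([3, 1, 7, 5])

for eps in (0.0, 0.5, 1.5):
    print(eps, certified_accuracy(radii, preds, labels, eps))
```

A misclassified point never counts toward certified accuracy, however large its certified radius, which is why the last point contributes nothing here.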
76 |
77 | Note that other noises need to be instantiated with their own arguments when the training/testing code is invoked. For example, if we want to sample noise $\propto \|x\|_\infty^{-100}e^{-\|x\|_\infty^{10}}$, we would run:
78 |
79 | ```shell
80 | python3 -m src.train \
81 |     --noise=ExpInf \
82 |     --k=10 \
83 |     --j=100 \
84 |     --sigma=0.5 \
85 |     --experiment-name=cifar_expinf_0.5
86 | ```
87 |
88 | ## Trained Models
89 |
90 | Our pre-trained models are available.
91 |
92 | The following commands will download all models into the `pretrain/` directory.
93 |
94 | ```shell
95 | mkdir -p pretrain
96 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip
97 | unzip -d pretrain pretrain/cifar_all.zip
98 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip
99 | unzip -d pretrain pretrain/imagenet_all.zip
100 | ```
101 |
102 | ImageNet (ResNet-50): [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip)
103 |
104 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_025.pt)
105 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_050.pt)
106 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_075.pt)
107 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_100.pt)
108 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_125.pt)
109 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_150.pt)
110 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_175.pt)
111 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_200.pt)
112 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_225.pt)
113 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_250.pt)
114 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_275.pt)
115 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_300.pt)
116 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_325.pt)
117 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_350.pt)
118 |
119 | CIFAR-10 (Wide ResNet 40-2): [[All Models, 226 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip)
120 |
121 | - Sigma=0.15: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_015.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_015.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_015.pt)
122 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_025.pt)
123 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_050.pt)
124 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_075.pt)
125 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_100.pt)
126 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_125.pt)
127 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_150.pt)
128 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_175.pt)
129 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_200.pt)
130 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_225.pt)
131 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_250.pt)
132 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_275.pt)
133 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_300.pt)
134 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_325.pt)
135 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_350.pt)
136 |
137 | By default the models above were trained with noise augmentation. We further improve upon our state-of-the-art certified accuracies using recent advances in training smoothed classifiers: (1) by using stability training (Li et al. NeurIPS 2019), and (2) by leveraging additional data using (a) pre-training on downsampled ImageNet (Hendrycks et al. NeurIPS 2019) and (b) semi-supervised self-training with data from 80 Million Tiny Images (Carmon et al. 2019). Our improved models trained with these methods are released below.
138 |
139 | ImageNet (ResNet 50):
140 |
141 | - Stability training: [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_stability.zip)
142 |
143 | CIFAR-10 (Wide ResNet 40-2):
144 |
145 | - Stability training: [[All Models, 234 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_stability.zip)
146 | - Stability training + pre-training: [[All Models, 236 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_pretrained.zip)
147 | - Stability training + semi-supervised learning: [[All Models, 235 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_semisup.zip)
148 |
149 | An example of pre-trained model usage is below. For a more in-depth example, see our `tutorial.ipynb` notebook.
150 |
151 | ```python
152 | from src.models import WideResNet
153 | from src.noises import Uniform
154 | from src.smooth import *
155 |
156 | # load the model
157 | model = WideResNet(dataset="cifar", device="cuda")
158 | saved_dict = torch.load("pretrain/cifar_uniform_050.pt")
159 | model.load_state_dict(saved_dict)
160 | model.eval()
161 |
162 | # instantiation of noise
163 | noise = Uniform(device="cpu", dim=3072, sigma=0.5)
164 |
165 | # training code, to generate samples
166 | noisy_x = noise.sample(x)
167 |
168 | # testing code, certify for L1 adversary
169 | preds = smooth_predict_hard(model, x, noise, 64)
170 | top_cats = preds.probs.argmax(dim=1)
171 | prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise, 100000)
172 | radius = noise.certify(prob_lb, adv=1)
173 | ```
174 |
175 | ## Repository
176 |
177 | 1. `ckpts/` is used to store experiment checkpoints and results.
178 | 2. `data/` is used to store image datasets.
179 | 3. `tables/` contains caches of pre-calculated tables of certified radii.
180 | 4. `src/` contains the main source code.
181 | 5. `scripts/` contains the analysis and plotting code.
182 |
183 | Within the `src/` directory, the most salient files are:
184 |
185 | 1. `train.py` is used to train models and save to `ckpts/`.
186 | 2. `test.py` is used to test and compute robust certificates for $\ell_1,\ell_2,\ell_\infty$ adversaries.
187 | 3. `noises/test_noises.py` is a unit test for the noises we include. Run the test with
188 |
189 | ```python -m unittest src/noises/test_noises.py```
190 |
191 | Note that some tests are probabilistic and can fail occasionally.
192 | If so, rerun a few more times to make sure the failure is not persistent.
193 |
194 | 4. `noises/noises.py` is a library of noises derived for randomized smoothing.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Randomized Smoothing of All Shapes and Sizes
2 |
3 | Last update: July 2020.
4 |
5 | ---
6 |
7 | Code to accompany our paper:
8 |
9 | **Randomized Smoothing of All Shapes and Sizes**
10 | *Greg Yang\*, Tony Duan\*, J. Edward Hu, Hadi Salman, Ilya Razenshteyn, Jerry Li.*
11 | International Conference on Machine Learning (ICML), 2020 [[Paper]](https://arxiv.org/abs/2002.08118) [[Blog Post]](http://decentdescent.org/rs4a1.html)
12 |
13 | Notably, we outperform existing provably $\ell_1$-robust classifiers on ImageNet and CIFAR-10.
14 |
15 | 
16 |
17 | 
18 |
19 | This library implements the algorithms in our paper for computing robust radii for different smoothing distributions against different adversaries; for example, distributions of the form $e^{-\|x\|_\infty^k}$ against an $\ell_1$ adversary.
20 |
21 | The following summarizes the (distribution, adversary) pairs covered here.
22 |
23 | 
24 |
25 | We can compare the certified robust radius each of these distributions implies at a fixed level of $\hat\rho_\mathrm{lower}$, the lower bound on the probability that the classifier returns the top class under noise. Here all noises are instantiated for CIFAR-10 dimensionality ($d=3072$) and normalized to variance $\sigma^2 \triangleq \mathbb{E}[\|x\|_2^2]=1$. Note that the first two rows below certify for the $\ell_1$ adversary while the last row certifies for the $\ell_2$ adversary and the $\ell_\infty$ adversary. For more details see our `tutorial.ipynb` notebook.
26 |
27 | 
28 |
29 | ## Getting Started
30 |
31 | Clone our repository and install dependencies:
32 |
33 | ```shell
34 | git clone https://github.com/tonyduan/rs4a.git
35 | conda create --name rs4a python=3.6
36 | conda activate rs4a
37 | conda install numpy matplotlib pandas seaborn
38 | conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
39 | pip install torchnet tqdm statsmodels dfply
40 | ```
41 |
42 | ## Experiments
43 |
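44 | To reproduce our SOTA $\ell_1$ results on CIFAR-10, we need to train models over
45 | $$\sigma \in \{0.15, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5\}.$$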
46 | For each value, run the following:
47 |
48 | ```shell
49 | python3 -m src.train \
50 |     --model=WideResNet \
51 |     --noise=Uniform \
52 |     --sigma={sigma} \
53 |     --experiment-name=cifar_uniform_{sigma}
54 |
55 | python3 -m src.test \
56 |     --model=WideResNet \
57 |     --noise=Uniform \
58 |     --sigma={sigma} \
59 |     --experiment-name=cifar_uniform_{sigma} \
60 |     --sample-size-cert=100000 \
61 |     --sample-size-pred=64 \
62 |     --noise-batch-size=512
63 | ```
64 |
65 | The training script will train the model via data augmentation for the specified noise and level of sigma, and save the model checkpoint to a directory `ckpts/experiment_name`.
66 |
67 | The testing script will load the model checkpoint from the `ckpts/experiment_name` directory, make predictions over the entire test set using the smoothed classifier, and certify the $\ell_1, \ell_2,$ and $\ell_\infty$ robust radii of these predictions. Note that by default we make predictions with $64$ samples, certify with $100,000$ samples, and at a failure probability of $\alpha=0.001$.
68 |
69 | To draw a comparison to the benchmark noises, re-run the above replacing `Uniform` with `Gaussian` and `Laplace`. Then to plot the figures and print the table of results (for $\ell_1$ adversary), run our analysis script:
70 |
71 | ```shell
72 | python3 -m scripts.analyze --dir=ckpts --show --adv=1
73 | ```
74 |
75 | Note that other noises need to be instantiated with their own arguments when the training/testing code is invoked. For example, if we want to sample noise $\propto \|x\|_\infty^{-100}e^{-\|x\|_\infty^{10}}$, we would run:
76 |
77 | ```shell
78 | python3 -m src.train \
79 |     --noise=ExpInf \
80 |     --k=10 \
81 |     --j=100 \
82 |     --sigma=0.5 \
83 |     --experiment-name=cifar_expinf_0.5
84 | ```
85 |
86 | ## Trained Models
87 |
88 | Our pre-trained models are available.
89 |
90 | The following commands will download all models into the `pretrain/` directory.
91 |
92 | ```shell
93 | mkdir -p pretrain
94 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip
95 | unzip -d pretrain pretrain/cifar_all.zip
96 | wget --directory-prefix=pretrain http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip
97 | unzip -d pretrain pretrain/imagenet_all.zip
98 | ```
99 |
100 | ImageNet (ResNet-50): [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_all.zip)
101 |
102 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_025.pt)
103 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_050.pt)
104 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_075.pt)
105 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_100.pt)
106 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_laplace_125.pt)
107 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_150.pt)
108 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_175.pt)
109 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_200.pt)
110 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_225.pt)
111 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_250.pt)
112 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_275.pt)
113 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_300.pt)
114 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_325.pt)
115 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_uniform_350.pt)
116 |
117 | CIFAR-10 (Wide ResNet 40-2): [[All Models, 226 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_all.zip)
118 |
119 | - Sigma=0.15: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_015.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_015.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_015.pt)
120 | - Sigma=0.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_025.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_025.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_025.pt)
121 | - Sigma=0.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_050.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_050.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_050.pt)
122 | - Sigma=0.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_075.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_075.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_075.pt)
123 | - Sigma=1.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_100.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_100.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_100.pt)
124 | - Sigma=1.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_125.pt) [[Gaussian]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_gaussian_125.pt) [[Laplace]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_laplace_125.pt)
125 | - Sigma=1.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_150.pt)
126 | - Sigma=1.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_175.pt)
127 | - Sigma=2.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_200.pt)
128 | - Sigma=2.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_225.pt)
129 | - Sigma=2.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_250.pt)
130 | - Sigma=2.75: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_275.pt)
131 | - Sigma=3.0: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_300.pt)
132 | - Sigma=3.25: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_325.pt)
133 | - Sigma=3.5: [[Uniform]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_uniform_350.pt)
134 |
135 | By default the models above were trained with noise augmentation. We further improve upon our state-of-the-art certified accuracies using recent advances in training smoothed classifiers: (1) by using stability training (Li et al. NeurIPS 2019), and (2) by leveraging additional data using (a) pre-training on downsampled ImageNet (Hendrycks et al. NeurIPS 2019) and (b) semi-supervised self-training with data from 80 Million Tiny Images (Carmon et al. 2019). Our improved models trained with these methods are released below.
136 |
137 | ImageNet (ResNet 50):
138 |
139 | - Stability training: [[All Models, 2.3 GB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/imagenet_stability.zip)
140 |
141 | CIFAR-10 (Wide ResNet 40-2):
142 |
143 | - Stability training: [[All Models, 234 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_stability.zip)
144 | - Stability training + pre-training: [[All Models, 236 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_pretrained.zip)
145 | - Stability training + semi-supervised learning: [[All Models, 235 MB]](http://www.tonyduan.com/resources/2020_rs4a_ckpts/cifar_semisup.zip)
146 |
147 | An example of pre-trained model usage is below. For a more in-depth example, see our `tutorial.ipynb` notebook.
148 |
149 | ```python
150 | from src.models import WideResNet
151 | from src.noises import Uniform
152 | from src.smooth import *
153 |
154 | # load the model
155 | model = WideResNet(dataset="cifar", device="cuda")
156 | saved_dict = torch.load("pretrain/cifar_uniform_050.pt")
157 | model.load_state_dict(saved_dict)
158 | model.eval()
159 |
160 | # instantiation of noise
161 | noise = Uniform(device="cpu", dim=3072, sigma=0.5)
162 |
163 | # training code, to generate samples
164 | noisy_x = noise.sample(x)
165 |
166 | # testing code, certify for L1 adversary
167 | preds = smooth_predict_hard(model, x, noise, 64)
168 | top_cats = preds.probs.argmax(dim=1)
169 | prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise, 100000)
170 | radius = noise.certify(prob_lb, adv=1)
171 | ```
172 |
173 | ## Repository
174 |
175 | 1. `ckpts/` is used to store experiment checkpoints and results.
176 | 2. `data/` is used to store image datasets.
177 | 3. `tables/` contains caches of pre-calculated tables of certified radii.
178 | 4. `src/` contains the main source code.
179 | 5. `scripts/` contains the analysis and plotting code.
180 |
181 | Within the `src/` directory, the most salient files are:
182 |
183 | 1. `train.py` is used to train models and save to `ckpts/`.
184 | 2. `test.py` is used to test and compute robust certificates for $\ell_1,\ell_2,\ell_\infty$ adversaries.
185 | 3. `noises/test_noises.py` is a unit test for the noises we include. Run the test with
186 |
187 | ```python -m unittest src/noises/test_noises.py```
188 |
189 | Note that some tests are probabilistic and can fail occasionally.
190 | If so, rerun a few more times to make sure the failure is not persistent.
191 |
192 | 4. `noises/noises.py` is a library of noises derived for randomized smoothing.
--------------------------------------------------------------------------------
/ckpts/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/ckpts/.gitkeep
--------------------------------------------------------------------------------
/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/data/.gitkeep
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/analyze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import pickle
5 | import matplotlib as mpl
6 | import seaborn as sns
7 | from argparse import ArgumentParser
8 | from collections import defaultdict
9 | from dfply import *
10 | from matplotlib import pyplot as plt
11 | from src.noises import *
12 | from src.datasets import get_dim
13 |
14 |
15 | if __name__ == "__main__":
16 |
17 |     argparser = ArgumentParser()
18 |     argparser.add_argument("--dir", default="./ckpts", type=str)
19 |     argparser.add_argument("--debug", action="store_true")
20 |     argparser.add_argument("--show", action="store_true")
21 |     argparser.add_argument("--adv", default=1, type=float)
22 |     argparser.add_argument("--eps-max", default=5.0, type=float)
23 |     argparser.add_argument("--fancy-markers", action="store_true")
24 |     args = argparser.parse_args()
25 |     args.adv = round(args.adv)
26 |
27 |     markers = ["o", "D", "s"] if args.fancy_markers else True
28 |
29 |     sns.set_context("notebook", rc={"lines.linewidth": 2})
30 |     sns.set_style("whitegrid")
31 |     sns.set_palette("husl")
32 |
33 |     dataset = args.dir.split("_")[0]
34 |     experiment_names = list(filter(lambda s: os.path.isdir(args.dir + "/" + s),
35 |                                    os.listdir(args.dir)))
36 |
37 |     df = defaultdict(list)
38 |     eps_range = np.linspace(0, args.eps_max, 81)
39 |
40 |     for experiment_name in experiment_names:
41 |
42 |         save_path = f"{args.dir}/{experiment_name}"
43 |         results = {}
44 |         experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb"))
45 |
46 |         for k in ("preds", "labels", f"radius_l{str(args.adv)}", "acc_train"):
47 |             results[k] = np.load(f"{save_path}/{k}.npy")
48 |
49 |         noise = parse_noise_from_args(experiment_args, device="cpu",
50 |                                       dim=get_dim(experiment_args.dataset))
51 |
52 |         top_1_preds = np.argmax(results["preds"], axis=1)
53 |         top_1_acc_pred = (top_1_preds == results["labels"]).mean()
54 |
55 |         if experiment_args.adversarial:
56 |             noise_str = noise.plotstr() + f",$\\epsilon={experiment_args.eps}$"
57 |         else:
58 |             noise_str = noise.plotstr()
59 |
60 |         for eps in eps_range:
61 |
62 |             top_1_acc_cert = ((results[f"radius_l{str(args.adv)}"] >= eps) & \
63 |                               (top_1_preds == results["labels"])).mean()
64 |             df["experiment_name"].append(experiment_name)
65 |             df["sigma"].append(noise.sigma)
66 |             df["noise"].append(noise_str)
67 |             df["eps"].append(eps)
68 |             df["top_1_acc_train"].append(results["acc_train"][0])
69 |             df["top_1_acc_cert"].append(top_1_acc_cert)
70 |             df["top_1_acc_pred"].append(top_1_acc_pred)
71 |
72 | # save the experiment results
73 | df = pd.DataFrame(df) >> arrange(X.noise)
74 | df.to_csv(f"{args.dir}/results_{dataset}_l{str(args.adv)}.csv", index=False)
75 |
76 | if args.debug:
77 | breakpoint()
78 |
79 | # print top-1 certified accuracies for table in paper
80 | print(df >> mask(X.eps.isin((0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0))) \
81 | >> group_by(X.eps, X.noise) >> arrange(X.top_1_acc_cert, ascending=False) >> head(1))
82 |
83 | # plot clean training accuracy against certified accuracy at eps
84 | # tmp = df >> mask(X.eps == 0.25) >> arrange(X.noise)
85 | # plt.figure(figsize=(3, 2.8))
86 | # ax = sns.scatterplot(x="top_1_acc_train", y="top_1_acc_cert", hue="noise", style="noise",
87 | # markers=markers, size="sigma", data=tmp, legend="full")
88 | # handles, labels = ax.get_legend_handles_labels()
89 | # i = [i for i, t in enumerate(ax.legend_.texts) if t.get_text() == "sigma"][0]
90 | # ax.legend(handles[:i], labels[:i])
91 | # plt.plot(np.linspace(0.0, 1.0), np.linspace(0.0, 1.0), "--", color="gray")
92 | # plt.ylim((0.2, 1.0))
93 | # plt.xlim((0.2, 1.0))
94 | # plt.xlabel("Top-1 training accuracy")
95 | # plt.ylabel("Top-1 certified accuracy, $\epsilon$ = 0.25")
96 | # plt.tight_layout()
97 | # plt.savefig(f"{args.dir}/train_vs_certified.pdf")
98 | #
99 | # tmp = df >> mask(X.eps.isin((0.25, 0.5, 0.75, 1.0))) >> \
100 | # mutate(tr=X.top_1_acc_train, cert=X.top_1_acc_cert)
101 | # fig = sns.relplot(data=tmp, kind="scatter", x="tr", y="cert",
102 | # hue="noise", col="eps", col_wrap=2, aspect=1, height=3, size="sigma")
103 | # fig.map_dataframe(plt.plot, (plt.xlim()[0], plt.xlim()[1]), (plt.xlim()[0], plt.xlim()[1]), 'k--').set_axis_labels("tr", "cert").add_legend()
104 | #
105 | # plot clean training and testing accuracy
106 | grouped = df >> group_by(X.experiment_name) \
107 | >> mask(X.sigma <= 1.25) \
108 | >> summarize(experiment_name=first(X.experiment_name),
109 | noise=first(X.noise),
110 | sigma=first(X.sigma),
111 | top_1_acc_train=first(X.top_1_acc_train),
112 | top_1_acc_pred=first(X.top_1_acc_pred))
113 |
114 | plt.figure(figsize=(6.5, 2.5))
115 | plt.subplot(1, 2, 1)
116 | sns.lineplot(x="sigma", y="top_1_acc_train", hue="noise", markers=markers,
117 | style="noise", data=grouped, alpha=1)
118 |     plt.xlabel(r"$\sigma$")
119 | plt.ylabel("Top-1 training accuracy")
120 | plt.ylim((0, 1))
121 | plt.subplot(1, 2, 2)
122 | sns.lineplot(x="sigma", y="top_1_acc_pred", hue="noise", markers=markers,
123 | style="noise", data=grouped, alpha=1, legend=False)
124 |     plt.xlabel(r"$\sigma$")
125 | plt.ylabel("Top-1 testing accuracy")
126 | plt.ylim((0, 1))
127 | plt.tight_layout()
128 | plt.savefig(f"{args.dir}/train_test_accuracies.pdf")
129 |
130 | # plot certified accuracies
131 | selected = df >> mutate(certacc=X.top_1_acc_cert)
132 | sns.relplot(x="eps", y="certacc", hue="noise", kind="line", col="sigma",
133 | data=selected, height=2, aspect=1.5, col_wrap=2)
134 | plt.ylim((0, 1))
135 | plt.tight_layout()
136 | plt.savefig(f"{args.dir}/per_sigma_l{str(args.adv)}.pdf")
137 |
138 | # plot top certified accuracy per epsilon, per type of noise
139 | grouped = df >> mask(X.noise != "Clean") \
140 | >> group_by(X.eps, X.noise) \
141 | >> arrange(X.top_1_acc_cert, ascending=False) \
142 | >> summarize(top_1_acc_cert=first(X.top_1_acc_cert),
143 | noise=first(X.noise))
144 |
145 | plt.figure(figsize=(3.0, 2.8))
146 | sns.lineplot(x="eps", y="top_1_acc_cert", data=grouped, hue="noise", style="noise")
147 | plt.ylim((0, 1))
148 | plt.xlabel(f"$\\ell_{str(args.adv)}$ radius")
149 | plt.ylabel("Top-1 certified accuracy")
150 | plt.tight_layout()
151 | plt.savefig(f"{args.dir}/certified_accuracies_l{str(args.adv)}.pdf", bbox_inches="tight")
152 |
153 | if args.show:
154 | plt.show()
155 |
156 |
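The certified-accuracy column computed in `analyze.py` counts a test point at radius `eps` only when the smoothed prediction is correct and its certified radius is at least `eps`. A minimal pure-Python sketch of that rule follows. The helper name `certified_accuracy` and the sample values are hypothetical and are not part of the repository.

```python
def certified_accuracy(radii, preds, labels, eps):
    """Fraction of points that are correctly classified AND certified at radius >= eps."""
    hits = [r >= eps and p == y for r, p, y in zip(radii, preds, labels)]
    return sum(hits) / len(hits)


# Illustrative values only, not results from any experiment.
radii = [0.0, 0.6, 1.2, 2.0]
preds = [1, 0, 2, 2]
labels = [1, 0, 2, 1]

# One certified-accuracy value per radius, mirroring the loop over eps_range.
curve = {eps: certified_accuracy(radii, preds, labels, eps)
         for eps in (0.0, 0.5, 1.0, 1.5)}
print(curve)
```

The last point has the largest radius but a wrong prediction, so it never counts. This is why the curve at `eps = 0` equals the plain top-1 accuracy and can only decrease as `eps` grows.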
--------------------------------------------------------------------------------
/scripts/analyze_verify.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import os
4 | import pickle
5 | import matplotlib as mpl
6 | import seaborn as sns
7 | from argparse import ArgumentParser
8 | from collections import defaultdict
9 | from dfply import *
10 | from matplotlib import pyplot as plt
11 |
12 |
13 | if __name__ == "__main__":
14 |
15 | argparser = ArgumentParser()
16 | argparser.add_argument("--experiment-name", default="cifar_uniform_05", type=str)
17 | argparser.add_argument("--dir", default="./ckpts", type=str)
18 | args = argparser.parse_args()
19 |
20 | sns.set_style("white")
21 | sns.set_palette("husl")
22 |
23 | df = defaultdict(list)
24 | eps_range = (3.0, 2.0, 1.0, 0.5, 0.25)
25 |
26 | save_path = f"{args.dir}/{args.experiment_name}"
27 | experiment_args = pickle.load(open(f"{args.dir}/{args.experiment_name}/args.pkl", "rb"))
28 | results = {}
29 |
30 | for k in ["preds_smooth", "radius_smooth", "labels"] + \
31 | [f"preds_adv_{eps}" for eps in eps_range]:
32 | results[k] = np.load(f"{save_path}/{k}.npy")
33 |
34 | top_1_preds_smooth = np.argmax(results["preds_smooth"], axis=1)
35 |
36 | for eps in eps_range:
37 |
38 | top_1_preds_adv = np.argmax(results[f"preds_adv_{eps}"], axis=1)
39 | top_1_acc_cert = ((results["radius_smooth"] >= eps) & \
40 | (top_1_preds_smooth == results["labels"])).mean()
41 | top_1_acc_adv = (top_1_preds_adv == results["labels"]).mean()
42 | df["eps"].append(eps)
43 | df["top_1_acc_cert"].append(top_1_acc_cert)
44 | df["top_1_acc_adv"].append(top_1_acc_adv)
45 |
46 |
47 | df = pd.DataFrame(df) >> gather("type", "top_1_acc", ["top_1_acc_cert", "top_1_acc_adv"])
48 |
49 | plt.figure(figsize=(5, 3))
50 | sns.lineplot(x="eps", y="top_1_acc", hue="type", data=df)
51 | plt.title(args.experiment_name)
52 | plt.ylim((0, 1))
53 | plt.tight_layout()
54 | plt.show()
55 |
56 |
--------------------------------------------------------------------------------
/scripts/check_noise_propagation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import seaborn as sns
4 | import torch
5 | import torch.nn.functional as F
6 | import os
7 | import itertools
8 | from argparse import ArgumentParser
9 | from collections import defaultdict
10 | from torch.utils.data import DataLoader, Subset
11 | from matplotlib import pyplot as plt
12 | from tqdm import tqdm
13 | from src.attacks import *
14 | from src.noises import *
15 | from src.models import *
16 | from src.datasets import get_dataset, get_dim
17 |
18 |
19 | def get_final_layer_mlp(model, x):
20 | out = model.model[0](x.reshape(x.shape[0], -1))
21 | out = model.model[1](out)
22 | out = model.model[2](out)
23 | out = model.model[3](out)
24 | return out
25 |
26 | def get_final_layer(model, x):
27 | out = model.model.conv1(x)
28 | out = model.model.block1(out)
29 | out = model.model.block2(out)
30 | out = model.model.block3(out)
31 | out = model.model.relu(model.model.bn1(out))
32 | out = F.avg_pool2d(out, 8)
33 | return out.view(-1, model.model.nChannels)
34 |
35 | if __name__ == "__main__":
36 |
37 | argparser = ArgumentParser()
38 | argparser.add_argument("--device", default="cuda:0", type=str)
39 |     argparser.add_argument("--batch-size", default=4, type=int)
40 | argparser.add_argument("--num-workers", default=os.cpu_count(), type=int)
41 | argparser.add_argument("--sample-size", default=64, type=int)
42 | argparser.add_argument("--dataset", default="cifar", type=str)
43 | argparser.add_argument("--dataset-skip", default=20, type=int)
44 | argparser.add_argument("--model", default="ResNet", type=str)
45 | argparser.add_argument("--dir", type=str, default="cifar_snapshots")
46 | argparser.add_argument("--load", action="store_true")
47 | args = argparser.parse_args()
48 |
49 | sns.set_style("whitegrid")
50 | sns.set_palette("husl")
51 |
52 | noises = ["Uniform", "Gaussian", "Laplace"]
53 | epochs = np.arange(1, 30, 1)
54 |
55 | test_dataset = get_dataset(args.dataset, "test")
56 | test_dataset = Subset(test_dataset, list(range(0, len(test_dataset), args.dataset_skip)))
57 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
58 | num_workers=args.num_workers)
59 |
60 | results = defaultdict(list)
61 |
62 | for noise_str, epoch in itertools.product(noises, epochs):
63 |
64 | if args.load:
65 | break
66 |
67 | sigma = 1.0
68 |
69 | save_path = f"{args.dir}/cifar_{noise_str}_{sigma}/{epoch-1}/model_ckpt.torch"
70 | model = eval(args.model)(dataset=args.dataset, device=args.device)
71 | model.load_state_dict(torch.load(save_path))
72 | model.eval()
73 |
74 | noise = eval(noise_str)(sigma=sigma, device=args.device, p=2, dim=get_dim(args.dataset))
75 |
76 | for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)):
77 |
78 | x, y = x.to(args.device), y.to(args.device)
79 |
80 | v = x.unsqueeze(1).expand((args.batch_size, args.sample_size, 3, 32, 32))
81 | v = v.reshape((-1, 3, 32, 32))
82 | noised = noise.sample(v)
83 | if args.model == "ResNet":
84 | rep_noisy = get_final_layer(model, noised)
85 | elif args.model == "MLP":
86 | rep_noisy = get_final_layer_mlp(model, noised)
87 | else:
88 | raise ValueError
89 | rep_noisy = rep_noisy.reshape(args.batch_size, -1, rep_noisy.shape[-1])
90 |
91 | top_cats = model(noised).reshape(args.batch_size, -1, 10).argmax(dim=2).mode(dim=1)
92 | top_cats = top_cats.values
93 |
94 | l2 = torch.stack([F.pdist(rep_i, p=2) for rep_i in rep_noisy]).mean(dim=1).data
95 | l1 = torch.stack([F.pdist(rep_i, p=1) for rep_i in rep_noisy]).mean(dim=1).data
96 | linf = torch.stack([F.pdist(rep_i, p=float("inf")) for rep_i in rep_noisy]).mean(dim=1).data
97 |
98 | results["acc"] += (y == top_cats).float().cpu().numpy().tolist()
99 | results["l1"] += l1.cpu().numpy().tolist()
100 | results["l2"] += l2.cpu().numpy().tolist()
101 | results["linf"] += linf.cpu().numpy().tolist()
102 | results["noise"] += args.batch_size * [noise_str]
103 | results["epoch"] += args.batch_size * [epoch]
104 |
105 | if args.load:
106 | results = pd.read_csv(f"{args.dir}/snapshots.csv")
107 | else:
108 | results = pd.DataFrame(results)
109 | results.to_csv(f"{args.dir}/snapshots.csv")
110 |
111 | plt.figure(figsize=(10, 6))
112 | plt.subplot(2, 2, 1)
113 | sns.lineplot(x="epoch", y="l2", hue="noise", data=results)
114 | plt.xlabel("Epoch")
115 | plt.ylabel("L2")
116 | plt.legend()
117 | plt.subplot(2, 2, 2)
118 | sns.lineplot(x="epoch", y="l1", hue="noise", data=results)
119 | plt.xlabel("Epoch")
120 | plt.ylabel("L1")
121 | plt.legend()
122 | plt.subplot(2, 2, 3)
123 | sns.lineplot(x="epoch", y="linf", hue="noise", data=results)
124 | plt.xlabel("Epoch")
125 | plt.ylabel("Linf")
126 | plt.legend()
127 | plt.subplot(2, 2, 4)
128 | sns.lineplot(x="epoch", y="acc", hue="noise", data=results)
129 | plt.xlabel("Epoch")
130 | plt.ylabel("Acc")
131 | plt.ylim((0, 1))
132 | plt.legend()
133 | plt.show()
134 |
135 |
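`check_noise_propagation.py` summarizes how far apart the noisy copies of one input land in representation space by averaging `F.pdist` over all pairs of samples, for the L1, L2, and Linf norms. The sketch below reproduces that statistic in pure Python for a tiny hand-made set of vectors. The helper name `mean_pairwise_distance` is hypothetical, and the vectors are illustrative rather than real network activations.

```python
from itertools import combinations


def mean_pairwise_distance(rows, p=2.0):
    """Mean p-norm distance over all unordered pairs of rows."""
    dists = []
    for a, b in combinations(rows, 2):
        diffs = [abs(x - y) for x, y in zip(a, b)]
        if p == float("inf"):
            dists.append(max(diffs))
        else:
            dists.append(sum(d ** p for d in diffs) ** (1.0 / p))
    return sum(dists) / len(dists)


# Three "representations" of the same input under different noise draws.
reps = [[0.0, 0.0], [3.0, 4.0], [3.0, 0.0]]

l1 = mean_pairwise_distance(reps, p=1.0)
l2 = mean_pairwise_distance(reps, p=2.0)
linf = mean_pairwise_distance(reps, p=float("inf"))
print(l1, l2, linf)
```

A smaller mean distance means the representation is less sensitive to the injected noise, which is the quantity the script tracks per epoch and per noise type.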
--------------------------------------------------------------------------------
/scripts/plot_losses.py:
--------------------------------------------------------------------------------
1 | #
2 | # Plot the training loss for each model trained.
3 | #
4 | import numpy as np
5 | import pandas as pd
6 | import os
7 | import pickle
8 | import matplotlib as mpl
9 | import seaborn as sns
10 | from argparse import ArgumentParser
11 | from dfply import *
12 | from matplotlib import pyplot as plt
13 | from src.noises import *
14 | from src.datasets import get_dim
15 |
16 |
17 | if __name__ == "__main__":
18 |
19 | argparser = ArgumentParser()
20 | argparser.add_argument("--dir", default="./ckpts", type=str)
21 | args = argparser.parse_args()
22 |
23 | dataset = args.dir.split("_")[0]
24 | experiment_names = list(filter(lambda s: os.path.isdir(args.dir + "/" + s),
25 | os.listdir(args.dir)))
26 |
27 | sns.set_style("white")
28 | sns.set_palette("husl")
29 |
30 |     losses_df = pd.DataFrame({"experiment_name": [], "noise": [], "sigma": [], "losses": [], "iter": []})
31 |
32 | for experiment_name in experiment_names:
33 |
34 | save_path = f"{args.dir}/{experiment_name}"
35 | experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb"))
36 | results = {}
37 |
38 | for k in ("losses_train",):
39 | results[k] = np.load(f"{save_path}/{k}.npy")
40 |
41 | noise = parse_noise_from_args(experiment_args, device="cpu",
42 | dim=get_dim(experiment_args.dataset))
43 |
44 | losses_df >>= bind_rows(pd.DataFrame({
45 | "experiment_name": experiment_name,
46 | "noise": noise.plotstr(),
47 | "sigma": experiment_args.sigma,
48 | "losses": results["losses_train"],
49 | "iter": np.arange(len(results["losses_train"]))}))
50 |
51 | # show training curves
52 | sns.relplot(x="iter", y="losses", hue="noise", data=losses_df, col="sigma",
53 | col_wrap=2, kind="line", height=1.5, aspect=3.5, alpha=0.5)
54 | plt.tight_layout()
55 | plt.show()
56 |
57 |
--------------------------------------------------------------------------------
/scripts/plot_preds_distn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy as sp
3 | import scipy.stats
4 | import pandas as pd
5 | import os
6 | import pickle
7 | import matplotlib as mpl
8 | import seaborn as sns
9 | from argparse import ArgumentParser
10 | from dfply import *
11 | from matplotlib import pyplot as plt
12 | from src.noises import parse_noise_from_args
13 | from src.datasets import get_dim
14 |
15 | if __name__ == "__main__":
16 |
17 | argparser = ArgumentParser()
18 | argparser.add_argument("--dir", default="./ckpts", type=str)
19 | argparser.add_argument("--target", default="prob_correct", type=str)
20 | argparser.add_argument("--adv", default=1, type=float)
21 | argparser.add_argument("--use-pdf", action="store_true")
22 | args = argparser.parse_args()
23 |
24 | dataset = args.dir.split("_")[0]
25 | experiment_names = list(filter(lambda x: x.startswith(dataset), os.listdir(args.dir)))
26 |
27 | sns.set_style("white")
28 | sns.set_palette("husl")
29 |
30 |     losses_df = pd.DataFrame({"experiment_name": [], "noise": [], "sigma": [], "cdf": [], "axis": []})
31 |
32 | for experiment_name in experiment_names:
33 |
34 | save_path = f"{args.dir}/{experiment_name}"
35 | experiment_args = pickle.load(open(f"{args.dir}/{experiment_name}/args.pkl", "rb"))
36 | results = {}
37 |
38 | noise = parse_noise_from_args(experiment_args, device="cpu",
39 | dim=get_dim(experiment_args.dataset))
40 |
41 | for k in ("prob_lb", "preds", "labels", f"radius_l{str(args.adv)}"):
42 | results[k] = np.load(f"{save_path}/{k}.npy")
43 |
44 | p_correct = results["preds"][np.arange(len(results["preds"])),
45 | results["labels"].astype(int)]
46 | p_top = results["prob_lb"]
47 |
48 | if args.target == "prob_lower_bound":
49 | tgt = p_top
50 | axis = np.linspace(0, 1, 500)
51 | elif args.target == "prob_correct":
52 | tgt = p_correct
53 | axis = np.linspace(0, 1, 500)
54 | elif args.target == "radius":
55 | tgt = results[f"radius_l{str(args.adv)}"]
56 | tgt = tgt[~np.isnan(tgt)]
57 | axis = np.linspace(0, 4.0, 500)
58 | else:
59 | raise ValueError
60 |
61 | if args.use_pdf:
62 | cdf = sp.stats.gaussian_kde(tgt)(axis)
63 | else:
64 | cdf = (tgt < axis[:, np.newaxis]).mean(axis=1)
65 |
66 | losses_df >>= bind_rows(pd.DataFrame({
67 | "experiment_name": experiment_name,
68 | "noise": noise.plotstr(),
69 | "sigma": experiment_args.sigma,
70 | "cdf": cdf,
71 | "axis": axis}))
72 |
73 |     # show the distribution of the target (empirical CDF, or KDE density with --use-pdf)
74 | sns.relplot(x="axis", y="cdf", hue="noise", data=losses_df, col="sigma",
75 | col_wrap=2, kind="line", height=1.5, aspect=2.5, alpha=0.5)
76 | plt.show()
77 |
78 |
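When `--use-pdf` is not set, `plot_preds_distn.py` builds an empirical CDF by comparing every target value against a grid of thresholds and averaging the strict `<` comparison. A small pure-Python sketch of the same computation follows. The function name `empirical_cdf` and the sample values are assumptions made for illustration.

```python
def empirical_cdf(values, grid):
    """For each threshold t in grid, the fraction of values strictly below t."""
    n = len(values)
    return [sum(v < t for v in values) / n for t in grid]


# Illustrative probabilities, not outputs of a trained model.
values = [0.2, 0.4, 0.4, 0.9]
grid = [0.0, 0.25, 0.5, 1.0]

cdf = empirical_cdf(values, grid)
print(cdf)
```

Because the comparison is strict, a value exactly equal to a threshold is not counted at that threshold, which matches the `tgt < axis[:, np.newaxis]` expression in the script.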
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/src/__init__.py
--------------------------------------------------------------------------------
/src/attacks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import grad
4 | from src.smooth import *
5 |
6 |
7 | def project_onto_ball(x, eps, p="inf"):
8 | """
9 |     Projection onto the inf-norm and 2-norm balls takes O(d) time; projection onto the
10 |     1-norm ball takes O(d log d) using the sorting-based algorithm of [Duchi et al. 2008].
11 | """
12 | original_shape = x.shape
13 | x = x.view(x.shape[0], -1)
14 | assert not torch.isnan(x).any()
15 | if p == "inf":
16 | x = x.clamp(-eps, eps)
17 | elif p == 2:
18 | x = x.renorm(p=2, dim=0, maxnorm=eps)
19 | elif p == 1:
20 | mask = (torch.norm(x, p=1, dim=1) < eps).float().unsqueeze(1)
21 | mu, _ = torch.sort(torch.abs(x), dim=1, descending=True)
22 | cumsum = torch.cumsum(mu, dim=1)
23 | arange = torch.arange(1, x.shape[1] + 1, device=x.device)
24 | rho, _ = torch.max((mu * arange > (cumsum - eps)) * arange, dim=1)
25 | theta = (cumsum[torch.arange(x.shape[0]), rho.cpu() - 1] - eps) / rho
26 | proj = (torch.abs(x) - theta.unsqueeze(1)).clamp(min=0)
27 | x = mask * x + (1 - mask) * proj * torch.sign(x)
28 | else:
29 | raise ValueError("Can only project onto 1,2,inf norm balls.")
30 | return x.view(original_shape)
31 |
32 | def pgd_attack(model, x, y, eps, steps=20, adv="inf", clamp=(0, 1)):
33 | """
34 | Attack a model with PGD.
35 | """
36 | step_size = 2 * eps / steps
37 | x.requires_grad = True
38 | x_orig = x.clone().detach()
39 |
40 | for _ in range(steps):
41 | loss = model.loss(x, y).mean()
42 | grads = grad(loss, x)[0].reshape(x.shape[0], -1)
43 | if adv == 1:
44 | keep_vals = torch.kthvalue(grads.abs(), k=grads.shape[1] * 15 // 16, dim=1).values
45 | grads[torch.abs(grads) < keep_vals.unsqueeze(1)] = 0
46 | grads = torch.sign(grads)
47 | grads_norm = torch.norm(grads, dim=1, p=1)
48 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8)
49 | elif adv == 2:
50 | grads_norm = torch.norm(grads, dim=1, p=2)
51 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8)
52 | elif adv == "inf":
53 | grads = torch.sign(grads)
54 | else:
55 | raise ValueError
56 | diff = x + step_size * grads.reshape(x.shape) - x_orig
57 | diff = project_onto_ball(diff, eps, adv)
58 | x = (x_orig + diff).clamp(*clamp)
59 |
60 | loss = model.loss(x, y).mean()
61 | x = x.detach()
62 | x.requires_grad = False
63 | return x, loss
64 |
65 | def pgd_attack_smooth(model, x, y, eps, noise, sample_size, steps=20, adv="inf", clamp=(0, 1)):
66 | """
67 | Attack a smoothed model with PGD.
68 | """
69 | step_size = 2 * eps / steps
70 | x.requires_grad = True
71 | x_orig = x.clone().detach()
72 | rng = torch.cuda.get_rng_state_all()
73 |
74 | for _ in range(steps):
75 | torch.cuda.set_rng_state_all(rng)
76 | forecast = smooth_predict_soft(model, x, noise, sample_size)
77 | loss = -forecast.log_prob(y).mean()
78 | grads = grad(loss, x)[0].reshape(x.shape[0], -1)
79 | if adv == 1:
80 | keep_vals = torch.kthvalue(grads.abs(), k=grads.shape[1] * 15 // 16, dim=1).values
81 | grads[torch.abs(grads) < keep_vals.unsqueeze(1)] = 0
82 | grads = torch.sign(grads)
83 | grads_norm = torch.norm(grads, dim=1, p=1)
84 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8)
85 | elif adv == 2:
86 | grads_norm = torch.norm(grads, dim=1, p=2)
87 | grads = grads / (grads_norm.unsqueeze(1) + 1e-8)
88 | elif adv == "inf":
89 | grads = torch.sign(grads)
90 | else:
91 | raise ValueError
92 | diff = x + step_size * grads.reshape(x.shape) - x_orig
93 | diff = project_onto_ball(diff, eps, adv)
94 | x = (x_orig + diff).clamp(*clamp)
95 | # forecast = smooth_predict_hard(model, x, noise, sample_size).probs
96 | # print(_, (torch.argmax(forecast, dim=1) == y).sum() / float(x.shape[0]),
97 | # diff.reshape(x.shape[0], -1).norm(dim=1, p=1).mean(),
98 | # diff.reshape(x.shape[0], -1).norm(dim=1, p=2).mean())
99 |
100 | torch.cuda.set_rng_state_all(rng)
101 | forecast = smooth_predict_soft(model, x, noise, sample_size)
102 | loss = -forecast.log_prob(y).mean()
103 |
104 | x = x.detach()
105 | x.requires_grad = False
106 | return x, loss
107 |
108 |
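The `p == 1` branch of `project_onto_ball` follows the sorting-based scheme: sort the absolute values, find the largest index `rho` whose entry still exceeds the running threshold, and soft-threshold every coordinate by `theta`. The sketch below restates those steps for a single vector in pure Python so they can be checked by hand. The function name `project_onto_l1_ball` is hypothetical, and this is a simplified single-vector illustration rather than the batched PyTorch code above.

```python
def project_onto_l1_ball(v, eps):
    """Euclidean projection of v (a list of floats) onto {w : ||w||_1 <= eps}."""
    abs_v = [abs(a) for a in v]
    if sum(abs_v) <= eps:
        return list(v)  # already inside the ball, nothing to do

    mu = sorted(abs_v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, m in enumerate(mu, start=1):
        cumsum += m
        # Same test as `mu * arange > (cumsum - eps)`; the last j that passes is rho.
        if m * j > cumsum - eps:
            theta = (cumsum - eps) / j

    # Soft-threshold: shrink every magnitude by theta and keep the original sign.
    return [(1.0 if a >= 0 else -1.0) * max(abs(a) - theta, 0.0) for a in v]


w = project_onto_l1_ball([3.0, -1.0, 0.5], eps=2.0)
print(w, sum(abs(a) for a in w))
```

For this input the threshold is `theta = 1`, so only the largest coordinate survives and the result has 1-norm exactly `eps`.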
--------------------------------------------------------------------------------
/src/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from torchvision import datasets, transforms
4 | from torch.utils.data import ConcatDataset
5 | from src.lib.zipdata import ZipData
6 | from src.lib.ds_imagenet import DSImageNet
7 | from src.lib.cifar10_selftrained import CIFAR10SelfTrained
8 |
9 |
10 | def get_dim(name):
11 | if name == "cifar":
12 | return 3 * 32 * 32
13 | if name == "svhn":
14 | return 3 * 32 * 32
15 | if name == "ds-imagenet":
16 | return 3 * 32 * 32
17 | if name == "mnist":
18 | return 28 * 28
19 | if name == "imagenet":
20 | return 3 * 224 * 224
21 | if name == "fashion":
22 | return 28 * 28
23 | if name == "cifar10selftrained":
24 | return 3 * 32 * 32
25 |
26 | def get_num_labels(name):
27 | return 1000 if "imagenet" in name else 10
28 |
29 | def get_normalization_shape(name):
30 | if name == "cifar":
31 | return (3, 1, 1)
32 | if name == "imagenet":
33 | return (3, 1, 1)
34 | if name == "ds-imagenet":
35 | return (3, 1, 1)
36 | if name == "svhn":
37 | return (3, 1, 1)
38 | if name == "mnist":
39 | return (1, 1, 1)
40 | if name == "fashion":
41 | return (1, 1, 1)
42 | if name == "cifar10selftrained":
43 | return (3, 1, 1)
44 |
45 | def get_normalization_stats(name):
46 | if name == "cifar" or name == "cifar10selftrained":
47 | return {"mu": [0.4914, 0.4822, 0.4465], "sigma": [0.2023, 0.1994, 0.2010]}
48 | if name == "imagenet" or name == "ds-imagenet":
49 | return {"mu": [0.485, 0.456, 0.406], "sigma": [0.229, 0.224, 0.225]}
50 | if name == "svhn":
51 | return {"mu": [0.436, 0.442, 0.471], "sigma": [0.197, 0.200, 0.196]}
52 | if name == "mnist":
53 | return {"mu": [0.1307,], "sigma": [0.3081,]}
54 | if name == "fashion":
55 | return {"mu": [0.2849,], "sigma": [0.3516,]}
56 |
57 | def get_dataset(name, split):
58 |
59 | if name == "cifar" and split == "train":
60 | return datasets.CIFAR10("./data/cifar_10", train=True, download=True,
61 | transform=transforms.Compose([transforms.RandomCrop(32, padding=4),
62 | transforms.RandomHorizontalFlip(),
63 | transforms.ToTensor()]))
64 | if name == "cifar" and split == "test":
65 | return datasets.CIFAR10("./data/cifar_10", train=False, download=True,
66 | transform=transforms.ToTensor())
67 |
68 | if name == "imagenet" and split == "train":
69 | return ZipData("/mnt/bucket/imagenet/train.zip",
70 | "/mnt/bucket/imagenet/train_map.txt",
71 | transforms.Compose([transforms.RandomResizedCrop(224),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor()]))
74 |
75 | if name == "imagenet" and split == "test":
76 | return ZipData("/mnt/bucket/imagenet/val.zip",
77 | "/mnt/bucket/imagenet/val_map.txt",
78 | transforms.Compose([transforms.Resize(256),
79 | transforms.CenterCrop(224),
80 | transforms.ToTensor()]))
81 |
82 | if name == "ds-imagenet" and split == "train":
83 | return DSImageNet("/mnt/bucket/downsampled_imagenet/", split="train",
84 | transform=transforms.Compose([transforms.RandomHorizontalFlip(),
85 | transforms.ToTensor()]))
86 |
87 | if name == "ds-imagenet" and split == "test":
88 | return DSImageNet("/mnt/bucket/downsampled_imagenet/", split="test",
89 | transform=transforms.ToTensor())
90 |
91 | if name == "mnist":
92 | return datasets.MNIST("./data/mnist", train=(split == "train"), download=True,
93 | transform=transforms.ToTensor())
94 | if name == "fashion":
95 | return datasets.FashionMNIST("./data/fashion", train=(split == "train"), download=True,
96 | transform=transforms.ToTensor())
97 | if name == "svhn":
98 | return datasets.SVHN("./data/svhn", split=split, download=True,
99 | transform=transforms.ToTensor())
100 |
101 | if name == "cifar10selftrained" and split == "train":
102 | return ConcatDataset([CIFAR10SelfTrained("/mnt/bucket/cifar10_selftrained/ti_500K_pseudo_labeled.pickle",
103 | transform=transforms.Compose([transforms.RandomHorizontalFlip(),
104 | transforms.ToTensor()])),
105 | datasets.CIFAR10("./data/cifar_10", train=True, download=True,
106 | transform=transforms.Compose([transforms.RandomCrop(32, padding=4),
107 | transforms.RandomHorizontalFlip(),
108 | transforms.ToTensor()])),
109 | ])
110 |
111 | raise ValueError
112 |
--------------------------------------------------------------------------------
/src/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/src/lib/__init__.py
--------------------------------------------------------------------------------
/src/lib/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AlexNet(nn.Module):
7 |
8 | def __init__(self, num_classes, drop_rate=0.5):
9 | super().__init__()
10 | self.features = nn.Sequential(
11 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
12 | nn.ReLU(),
13 | nn.MaxPool2d(kernel_size=2),
14 | nn.Conv2d(64, 192, kernel_size=3, padding=1),
15 | nn.ReLU(),
16 | nn.MaxPool2d(kernel_size=2),
17 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
18 | nn.ReLU(),
19 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
20 | nn.ReLU(),
21 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
22 | nn.ReLU(),
23 | nn.MaxPool2d(kernel_size=2),
24 | )
25 | self.classifier = nn.Sequential(
26 | nn.Dropout(p=drop_rate),
27 | nn.Linear(256 * 2 * 2, 4096),
28 | nn.ReLU(),
29 | nn.Dropout(p=drop_rate),
30 | nn.Linear(4096, 4096),
31 | nn.ReLU(),
32 | nn.Linear(4096, num_classes),
33 | )
34 |
35 | def forward(self, x):
36 | x = self.features(x)
37 | x = x.view(x.size(0), 256 * 2 * 2)
38 | x = self.classifier(x)
39 | return x
40 |
41 |
--------------------------------------------------------------------------------
/src/lib/cifar10_selftrained.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import os
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class CIFAR10SelfTrained(Dataset):
10 |
11 | def __init__(self, path, transform=None, target_transform=None):
12 | with open(path, "rb") as fd:
13 | self.dataset = pickle.load(fd)
14 | self.transform = transform
15 | self.target_transform = target_transform
16 |
17 | def __getitem__(self, index):
18 |
19 | img = self.dataset["data"][index]
20 | target = self.dataset["extrapolated_targets"][index]
21 |
22 |         # convert to a PIL Image so that the output is consistent
23 |         # with all other datasets
24 | img = Image.fromarray(img)
25 |
26 | if self.transform is not None:
27 | img = self.transform(img)
28 |
29 | if self.target_transform is not None:
30 | target = self.target_transform(target)
31 |
32 | return img, target
33 |
34 | def __len__(self):
35 | return len(self.dataset["extrapolated_targets"])
36 |
37 |
--------------------------------------------------------------------------------
/src/lib/classic_resnet.py:
--------------------------------------------------------------------------------
1 | '''
2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
3 |
4 | The implementation and structure of this file are hugely influenced by [2],
5 | which is implemented for ImageNet and doesn't have option A for identity.
6 | Moreover, most of the implementations on the web are copy-pasted from
7 | torchvision's resnet and have the wrong number of params.
8 |
9 | Proper ResNet-s for CIFAR10 (for fair comparison etc.) have the following
10 | number of layers and parameters:
11 |
12 | name | layers | params
13 | ResNet20 | 20 | 0.27M
14 | ResNet32 | 32 | 0.46M
15 | ResNet44 | 44 | 0.66M
16 | ResNet56 | 56 | 0.85M
17 | ResNet110 | 110 | 1.7M
18 | ResNet1202| 1202 | 19.4M
19 |
20 | which this implementation indeed has.
21 |
22 | Reference:
23 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
25 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
26 |
27 | If you use this implementation in your work, please don't forget to mention the
28 | author, Yerlan Idelbayev.
29 | '''
30 | import torch
31 | import torch.nn as nn
32 | import torch.nn.functional as F
33 | import torch.nn.init as init
34 |
35 | from torch.autograd import Variable
36 |
37 | __all__ = ['ResNet', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']
38 |
39 | def _weights_init(m):
40 | classname = m.__class__.__name__
41 | #print(classname)
42 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
43 | init.kaiming_normal_(m.weight)
44 |
45 | class LambdaLayer(nn.Module):
46 | def __init__(self, lambd):
47 | super(LambdaLayer, self).__init__()
48 | self.lambd = lambd
49 |
50 | def forward(self, x):
51 | return self.lambd(x)
52 |
53 |
54 | class BasicBlock(nn.Module):
55 | expansion = 1
56 |
57 | def __init__(self, in_planes, planes, stride=1, option='A'):
58 | super(BasicBlock, self).__init__()
59 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
60 | self.bn1 = nn.BatchNorm2d(planes)
61 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
62 | self.bn2 = nn.BatchNorm2d(planes)
63 |
64 | self.shortcut = nn.Sequential()
65 | if stride != 1 or in_planes != planes:
66 | if option == 'A':
67 | """
68 |                 For CIFAR10, the ResNet paper uses option A.
69 | """
70 | self.shortcut = LambdaLayer(lambda x:
71 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
72 | elif option == 'B':
73 | self.shortcut = nn.Sequential(
74 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
75 | nn.BatchNorm2d(self.expansion * planes)
76 | )
77 |
78 | def forward(self, x):
79 | out = F.relu(self.bn1(self.conv1(x)))
80 | out = self.bn2(self.conv2(out))
81 | out += self.shortcut(x)
82 | out = F.relu(out)
83 | return out
84 |
85 |
86 | class ResNet(nn.Module):
87 | def __init__(self, block, num_blocks, num_classes=10):
88 | super(ResNet, self).__init__()
89 | self.in_planes = 16
90 |
91 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
92 | self.bn1 = nn.BatchNorm2d(16)
93 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
94 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
95 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
96 | self.linear = nn.Linear(64, num_classes)
97 |
98 | self.apply(_weights_init)
99 |
100 | def _make_layer(self, block, planes, num_blocks, stride):
101 | strides = [stride] + [1]*(num_blocks-1)
102 | layers = []
103 | for stride in strides:
104 | layers.append(block(self.in_planes, planes, stride))
105 | self.in_planes = planes * block.expansion
106 |
107 | return nn.Sequential(*layers)
108 |
109 | def forward(self, x):
110 | out = F.relu(self.bn1(self.conv1(x)))
111 | out = self.layer1(out)
112 | out = self.layer2(out)
113 | out = self.layer3(out)
114 | out = F.avg_pool2d(out, out.size()[3])
115 | out = out.view(out.size(0), -1)
116 | out = self.linear(out)
117 | return out
118 |
119 |
120 | def resnet20():
121 | return ResNet(BasicBlock, [3, 3, 3])
122 |
123 |
124 | def resnet32():
125 | return ResNet(BasicBlock, [5, 5, 5])
126 |
127 |
128 | def resnet44():
129 | return ResNet(BasicBlock, [7, 7, 7])
130 |
131 |
132 | def resnet56():
133 | return ResNet(BasicBlock, [9, 9, 9])
134 |
135 |
136 | def resnet110():
137 | return ResNet(BasicBlock, [18, 18, 18])
138 |
139 |
140 | def resnet1202():
141 | return ResNet(BasicBlock, [200, 200, 200])
142 |
143 |
144 | def test(net):
145 | import numpy as np
146 | total_params = 0
147 |
148 | for x in filter(lambda p: p.requires_grad, net.parameters()):
149 | total_params += np.prod(x.data.numpy().shape)
150 | print("Total number of params", total_params)
151 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters()))))
152 |
153 |
154 | if __name__ == "__main__":
155 | for net_name in __all__:
156 | if net_name.startswith('resnet'):
157 | print(net_name)
158 | test(globals()[net_name]())
159 | print()
160 |
161 |
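The parameter counts in the docstring table of `classic_resnet.py` can be checked by hand, because with option A the shortcuts add no parameters and every convolution is 3x3 without bias. The sketch below is a pure-Python tally under those assumptions. The helper names are hypothetical, and the count treats each `BatchNorm2d` as contributing a weight and a bias per channel.

```python
def conv_params(c_in, c_out, k=3):
    """Weights of a k x k convolution created with bias=False."""
    return c_in * c_out * k * k


def bn_params(channels):
    """Learnable BatchNorm2d parameters: one weight and one bias per channel."""
    return 2 * channels


def cifar_resnet_params(blocks_per_stage, num_classes=10):
    """Parameter count of the CIFAR ResNet above with option-A (parameter-free) shortcuts."""
    total = conv_params(3, 16) + bn_params(16)  # stem: conv1 + bn1
    in_planes = 16
    for planes in (16, 32, 64):
        for _ in range(blocks_per_stage):
            # Each BasicBlock has two conv + batch-norm pairs.
            total += conv_params(in_planes, planes) + bn_params(planes)
            total += conv_params(planes, planes) + bn_params(planes)
            in_planes = planes
    total += 64 * num_classes + num_classes  # final linear layer with bias
    return total


print(cifar_resnet_params(3))   # ResNet20 -> 269722
print(cifar_resnet_params(9))   # ResNet56 -> 853018
```

Both values round to the 0.27M and 0.85M entries listed in the table.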
--------------------------------------------------------------------------------
/src/lib/ds_imagenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import os
4 | import torch
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | # from https://github.com/hendrycks/pre-training
10 | class DSImageNet(Dataset):
11 |     """Downsampled ImageNet datasets.
12 |     Args:
13 |         root (string): Root directory of the dataset, containing the pickled batch
14 |             files (``train_data_batch_1`` ... ``train_data_batch_10`` and ``val_data``).
15 |         split (string, optional): If ``"train"``, creates the dataset from the training
16 |             batches; otherwise creates it from the validation file.
17 |         transform (callable, optional): A function/transform that takes in a PIL image
18 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
19 |         target_transform (callable, optional): A function/transform that takes in the
20 |             target and transforms it.
21 | """
22 | train_list = [
23 | ['train_data_batch_1', ''],
24 | ['train_data_batch_2', ''],
25 | ['train_data_batch_3', ''],
26 | ['train_data_batch_4', ''],
27 | ['train_data_batch_5', ''],
28 | ['train_data_batch_6', ''],
29 | ['train_data_batch_7', ''],
30 | ['train_data_batch_8', ''],
31 | ['train_data_batch_9', ''],
32 | ['train_data_batch_10', '']
33 | ]
34 |
35 | test_list = [
36 | ['val_data', ''],
37 | ]
38 |
39 | def __init__(self, root, split="train", transform=None, target_transform=None):
40 | self.root = os.path.expanduser(root)
41 | self.transform = transform
42 | self.target_transform = target_transform
43 | self.split = split # training set or test set
44 | self.base_folder = ""
45 |
46 | # if not self._check_integrity():
47 | # raise RuntimeError('Dataset not found or corrupted.') # TODO
48 |
49 | # now load the pickled numpy arrays
50 | if split == "train":
51 | self.train_data = []
52 | self.train_labels = []
53 | for fentry in self.train_list:
54 | f = fentry[0]
55 | file = os.path.join(self.root, f)
56 | with open(file, 'rb') as fo:
57 | entry = pickle.load(fo)
58 | self.train_data.append(entry['data'])
59 | self.train_labels += [label - 1 for label in entry['labels']]
60 | self.mean = entry['mean']
61 |
62 | self.train_data = np.concatenate(self.train_data)
63 | self.train_data = self.train_data.reshape((self.train_data.shape[0], 3, 32, 32))
64 | self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
65 | else:
66 | f = self.test_list[0][0]
67 | file = os.path.join(self.root, f)
68 | fo = open(file, 'rb')
69 | entry = pickle.load(fo)
70 | self.test_data = entry['data']
71 | self.test_labels = [label - 1 for label in entry['labels']]
72 | fo.close()
73 | self.test_data = self.test_data.reshape((self.test_data.shape[0], 3, 32, 32))
74 | self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
75 |
76 | def __getitem__(self, index):
77 | """
78 | Args:
79 | index (int): Index
80 | Returns:
81 | tuple: (image, target) where target is index of the target class.
82 | """
83 | if self.split == "train":
84 | img, target = self.train_data[index], self.train_labels[index]
85 | else:
86 | img, target = self.test_data[index], self.test_labels[index]
87 |
88 | # doing this so that it is consistent with all other datasets
89 | # to return a PIL Image
90 | img = Image.fromarray(img)
91 |
92 | if self.transform is not None:
93 | img = self.transform(img)
94 |
95 | if self.target_transform is not None:
96 | target = self.target_transform(target)
97 |
98 | return img, target
99 |
100 | def __len__(self):
101 | if self.split == "train":
102 | return len(self.train_data)
103 | else:
104 | return len(self.test_data)
105 |
106 | def _check_integrity(self):
107 | root = self.root
108 | for fentry in (self.train_list + self.test_list):
109 | filename, md5 = fentry[0], fentry[1]
110 | fpath = os.path.join(root, self.base_folder, filename)
111 | if not check_integrity(fpath, md5):
112 | return False
113 | return True
114 |
115 |
116 |
--------------------------------------------------------------------------------
/src/lib/lenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class LeNet(nn.Module):
6 |
7 | def __init__(self, in_channels, out_dim):
8 | super(LeNet, self).__init__()
9 | self.conv1 = nn.Conv2d(in_channels, 6, 5, padding=2)
10 | self.conv2 = nn.Conv2d(6, 16, 5)
11 | self.fc1 = nn.Linear(16*5*5, 120)
12 | self.fc2 = nn.Linear(120, 84)
13 | self.fc3 = nn.Linear(84, out_dim)
14 |
15 | def forward(self, x):
16 | out = F.relu(self.conv1(x))
17 | out = F.max_pool2d(out, 2)
18 | out = F.relu(self.conv2(out))
19 | out = F.max_pool2d(out, 2)
20 | out = out.view(out.size(0), -1)
21 | out = F.relu(self.fc1(out))
22 | out = F.relu(self.fc2(out))
23 | out = self.fc3(out)
24 | return out
25 |
26 |
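The `16*5*5` input size of `fc1` in `LeNet` implies a particular spatial input size. The sketch below works through that arithmetic; the 28x28 input (MNIST-sized) is inferred from the layer shapes rather than stated in the source, and `conv_out` and `lenet_flat_features` are hypothetical helper names.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output side length of a square conv/pool window (floor division)."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet_flat_features(side):
    """Number of features reaching fc1 for a side x side input."""
    side = conv_out(side, kernel=5, padding=2)   # conv1 (padding=2 keeps the size)
    side = conv_out(side, kernel=2, stride=2)    # max_pool2d(out, 2)
    side = conv_out(side, kernel=5)              # conv2 (no padding)
    side = conv_out(side, kernel=2, stride=2)    # max_pool2d(out, 2)
    return 16 * side * side                      # conv2 has 16 output channels


print(lenet_flat_features(28))  # 400, matches 16*5*5
print(lenet_flat_features(32))  # 576, would not match fc1
```

Under these assumptions a 32x32 input would produce 576 features and fail at `fc1`, so this `LeNet` appears to be intended for 28x28 inputs only.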
--------------------------------------------------------------------------------
/src/lib/wide_resnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class BasicBlock(nn.Module):
8 |
9 | def __init__(self, in_planes, out_planes, stride, drop_rate=0.0):
10 | super(BasicBlock, self).__init__()
11 | self.bn1 = nn.BatchNorm2d(in_planes)
12 | self.relu1 = nn.ReLU(inplace=True)
13 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3,
14 | stride=stride, padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(out_planes)
16 | self.relu2 = nn.ReLU(inplace=True)
17 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3,
18 | stride=1, padding=1, bias=False)
19 | self.drop_rate = drop_rate
20 | self.equalInOut = (in_planes == out_planes)
21 | self.convShortcut = (not self.equalInOut) and \
22 | nn.Conv2d(in_planes, out_planes, kernel_size=1,
23 | stride=stride, padding=0, bias=False) or None
24 |
25 | def forward(self, x):
26 | if not self.equalInOut:
27 | x = self.relu1(self.bn1(x))
28 | else:
29 | out = self.relu1(self.bn1(x))
30 | if self.equalInOut:
31 | out = self.relu2(self.bn2(self.conv1(out)))
32 | else:
33 | out = self.relu2(self.bn2(self.conv1(x)))
34 | if self.drop_rate > 0:
35 | out = F.dropout(out, p=self.drop_rate, training=self.training)
36 | out = self.conv2(out)
37 | if not self.equalInOut:
38 | return torch.add(self.convShortcut(x), out)
39 | else:
40 | return torch.add(x, out)
41 |
42 |
43 | class NetworkBlock(nn.Module):
44 |
45 | def __init__(self, nb_layers, in_planes, out_planes, stride, drop_rate=0.0):
46 | super().__init__()
47 | layers = []
48 | for i in range(nb_layers):
49 | layers.append(BasicBlock(i == 0 and in_planes or \
50 | out_planes, out_planes, i == 0 and stride or 1, drop_rate))
51 | self.layer = nn.Sequential(*layers)
52 |
53 | def forward(self, x):
54 | return self.layer(x)
55 |
56 |
57 | class WideResNet(nn.Module):
58 |
59 | def __init__(self, depth, num_classes, widen_factor=1, drop_rate=0.0):
60 |
61 | super(WideResNet, self).__init__()
62 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
63 | assert ((depth - 4) % 6 == 0)
64 | n = (depth - 4) // 6
65 | block = BasicBlock
66 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], 1, drop_rate)
68 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], 2, drop_rate)
69 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], 2, drop_rate)
70 | self.bn1 = nn.BatchNorm2d(nChannels[3])
71 | self.relu = nn.ReLU(inplace=True)
72 | self.fc = nn.Linear(nChannels[3], num_classes)
73 | self.nChannels = nChannels[3]
74 |
75 | for m in self.modules():
76 | if isinstance(m, nn.Conv2d):
77 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
78 | m.weight.data.normal_(0, math.sqrt(2. / n))
79 | elif isinstance(m, nn.BatchNorm2d):
80 | m.weight.data.fill_(1)
81 | m.bias.data.zero_()
82 | elif isinstance(m, nn.Linear):
83 | m.bias.data.zero_()
84 |
85 | def forward(self, x):
86 | out = self.conv1(x)
87 | out = self.block1(out)
88 | out = self.block2(out)
89 | out = self.block3(out)
90 | out = self.relu(self.bn1(out))
91 | out = F.avg_pool2d(out, 8)
92 | out = out.view(-1, self.nChannels)
93 | return self.fc(out)
94 |
95 |
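The `WideResNet` constructor requires `depth = 6n + 4` and ends with `F.avg_pool2d(out, 8)`, which suggests 32x32 inputs. The sketch below reproduces that bookkeeping in plain Python; `wrn_config` is a hypothetical helper, and the default 32x32 input size is an assumption inferred from the pooling window.

```python
def wrn_config(depth, widen_factor, input_side=32):
    """Blocks per stage, channel widths, and final feature-map side length."""
    if (depth - 4) % 6 != 0:
        raise ValueError("depth must have the form 6n + 4")
    n = (depth - 4) // 6
    channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    side = input_side
    for stride in (1, 2, 2):  # strides of block1, block2, block3
        # 3x3 convolution with padding 1: out = (in + 2 - 3) // stride + 1
        side = (side - 1) // stride + 1
    return n, channels, side


print(wrn_config(40, 2))  # (6, [16, 32, 64, 128], 8)
```

With `depth=40` and `widen_factor=2` (the values used in `src/models.py`), the final feature map is 8x8, so the 8x8 average pool reduces it to a single value per channel. Other input sizes would leave more than one spatial position after pooling, which the subsequent `view(-1, self.nChannels)` does not account for.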
--------------------------------------------------------------------------------
/src/lib/zipdata.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os.path as op
3 | from threading import local
4 | from zipfile import ZipFile, BadZipFile
5 |
6 | from PIL import Image
7 | from io import BytesIO
8 | import torch.utils.data as data
9 |
10 | _VALID_IMAGE_TYPES = ['.jpg', '.jpeg', '.tiff', '.bmp', '.png']
11 |
12 | class ZipData(data.Dataset):
13 | _IGNORE_ATTRS = {'_zip_file'}
14 |
15 | def __init__(self, path, map_file,
16 | transform=None, target_transform=None,
17 | extensions=None):
18 | self._path = path
19 | if not extensions:
20 | extensions = _VALID_IMAGE_TYPES
21 | self._zip_file = ZipFile(path)
22 | self.zip_dict = {}
23 | self.samples = []
24 | self.transform = transform
25 | self.target_transform = target_transform
26 | self.class_to_idx = {}
27 | with open(map_file, 'r') as f:
28 | for line in iter(f.readline, ""):
29 | line = line.strip()
30 | if not line:
31 | continue
32 | cls_idx = [l for l in line.split('\t') if l]
33 | if not cls_idx:
34 | continue
35 | assert len(cls_idx) >= 2, "invalid line: {}".format(line)
36 | idx = int(cls_idx[1])
37 | cls = cls_idx[0]
38 | del cls_idx
39 | at_idx = cls.find('@')
40 | assert at_idx >= 0, "invalid class: {}".format(cls)
41 | cls = cls[at_idx + 1:]
42 | if cls.startswith('/'):
43 | # Python ZipFile expects no root
44 | cls = cls[1:]
45 | assert cls, "invalid class in line {}".format(line)
46 | prev_idx = self.class_to_idx.get(cls)
47 | assert prev_idx is None or prev_idx == idx, "class: {} idx: {} previously had idx: {}".format(
48 | cls, idx, prev_idx
49 | )
50 | self.class_to_idx[cls] = idx
51 |
52 | for fst in self._zip_file.infolist():
53 | fname = fst.filename
54 | target = self.class_to_idx.get(fname)
55 | if target is None:
56 | continue
57 | if fname.endswith('/') or fname.startswith('.') or fst.file_size == 0:
58 | continue
59 | ext = op.splitext(fname)[1].lower()
60 | if ext in extensions:
61 | self.samples.append((fname, target))
62 | assert len(self), "No images found in: {} with map: {}".format(self._path, map_file)
63 |
64 | def __repr__(self):
65 | return 'ZipData({}, size={})'.format(self._path, len(self))
66 |
67 | def __getstate__(self):
68 | return {
69 | key: val if key not in self._IGNORE_ATTRS else None
70 | for key, val in self.__dict__.items()
71 | }
72 |
73 | def __getitem__(self, index):
74 | proc = multiprocessing.current_process()
75 | pid = proc.pid # get pid of this process.
76 | if pid not in self.zip_dict:
77 | self.zip_dict[pid] = ZipFile(self._path)
78 | zip_file = self.zip_dict[pid]
79 |
80 | if index >= len(self) or index < 0:
81 | raise KeyError("{} is invalid".format(index))
82 | path, target = self.samples[index]
83 | try:
84 | sample = Image.open(BytesIO(zip_file.read(path))).convert('RGB')
85 | except BadZipFile:
86 | print("bad zip file")
87 | return None, None
88 | if self.transform is not None:
89 | sample = self.transform(sample)
90 | if self.target_transform is not None:
91 | target = self.target_transform(target)
92 | return sample, target
93 |
94 | def __len__(self):
95 | return len(self.samples)
96 |
97 |
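The map-file format that `ZipData.__init__` expects is only implicit in its parsing loop: each non-empty line holds a tab-separated entry and integer index, and the entry contains an `@` followed by the path inside the archive. The sketch below isolates that parsing step; `parse_map_line` is a hypothetical helper and the sample line is made up for illustration.

```python
def parse_map_line(line):
    """Return (path_in_zip, class_index), or None for a blank line."""
    line = line.strip()
    if not line:
        return None
    fields = [f for f in line.split("\t") if f]
    if len(fields) < 2:
        raise ValueError("invalid line: {}".format(line))
    entry, idx = fields[0], int(fields[1])
    at_idx = entry.find("@")
    if at_idx < 0:
        raise ValueError("invalid class: {}".format(entry))
    path = entry[at_idx + 1:]
    if path.startswith("/"):
        path = path[1:]  # ZipFile member names have no leading slash
    if not path:
        raise ValueError("invalid class in line {}".format(line))
    return path, idx


# Hypothetical map-file line, not taken from the repository.
print(parse_map_line("train.zip@/n01440764/img_0001.JPEG\t0\n"))
```

The resulting path is what the constructor later matches against `ZipFile.infolist()` member names, so entries whose path does not appear in the archive are silently skipped.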
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Normal, Categorical, Bernoulli
5 | from torchvision import models as base_models
6 | from src.datasets import *
7 | from src.lib.wide_resnet import WideResNet as WideResNetBase
8 | from src.lib.alexnet import AlexNet as AlexNetBase
9 | from src.lib.lenet import LeNet as LeNetBase
10 | from src.lib.classic_resnet import resnet110 as classic_resnet110
11 |
12 |
13 | class Forecaster(nn.Module):
14 |
15 | def __init__(self, dataset, device):
16 | super().__init__()
17 | self.device = device
18 | self.norm = NormalizeLayer(get_normalization_shape(dataset), device,
19 | **get_normalization_stats(dataset))
20 |
21 | def forward(self, x):
22 | raise NotImplementedError
23 |
24 | def forecast(self, theta):
25 | return Categorical(logits=theta)
26 |
27 | def loss(self, x, y):
28 | forecast = self.forecast(self.forward(x))
29 | return -forecast.log_prob(y)
30 |
31 |
32 | class ResNet(Forecaster):
33 |
34 | def __init__(self, dataset, device):
35 | super().__init__(dataset, device)
36 | if dataset == "imagenet":
37 | self.model = nn.DataParallel(base_models.resnet50(num_classes=1000))
38 | else:
39 | self.model = nn.DataParallel(classic_resnet110())
40 | self.norm = nn.DataParallel(self.norm)
41 | self.norm = self.norm.to(device)
42 | self.model = self.model.to(device)
43 |
44 | def forward(self, x):
45 | x = self.norm(x)
46 | return self.model(x)
47 |
48 |
49 | class WideResNet(Forecaster):
50 |
51 | def __init__(self, dataset, device):
52 | super().__init__(dataset, device)
53 | self.model = nn.DataParallel(WideResNetBase(depth=40, widen_factor=2,
54 | num_classes=get_num_labels(dataset)))
55 | self.norm = nn.DataParallel(self.norm)
56 | self.norm = self.norm.to(device)
57 | self.model = self.model.to(device)
58 |
59 | def forward(self, x):
60 | x = self.norm(x)
61 | return self.model(x)
62 |
63 |
64 | class LinearModel(Forecaster):
65 |
66 | def __init__(self, dataset, device):
67 | super().__init__(dataset, device)
68 | self.model = nn.Linear(get_dim(dataset), get_num_labels(dataset))
69 | self.model = self.model.to(device)
70 |
71 | def forward(self, x):
72 | x = self.norm(x).view(x.shape[0], -1)
73 | return self.model(x)
74 |
75 |
76 | class AlexNet(Forecaster):
77 |
78 | def __init__(self, dataset, device):
79 | super().__init__(dataset, device)
80 | self.model = AlexNetBase(get_num_labels(dataset), drop_rate=0.5)
81 | self.model = self.model.to(device)
82 |
83 | def forward(self, x):
84 | x = self.norm(x)
85 | return self.model(x)
86 |
87 |
88 | class LeNet(Forecaster):
89 |
90 | def __init__(self, dataset, device):
91 | super().__init__(dataset, device)
92 | self.model = LeNetBase(get_normalization_shape(dataset)[0], get_num_labels(dataset))
93 | self.model = self.model.to(device)
94 |
95 | def forward(self, x):
96 | x = self.norm(x)
97 | return self.model(x)
98 |
99 |
100 | class MLP(Forecaster):
101 |
102 | def __init__(self, dataset, device):
103 | super().__init__(dataset, device)
104 | self.model = nn.Sequential(
105 | nn.Linear(get_dim(dataset), 2048),
106 | nn.ReLU(),
107 | nn.Linear(2048, 512),
108 | nn.ReLU(),
109 | nn.Linear(512, get_num_labels(dataset)))
110 | self.model = self.model.to(device)
111 |
112 | def forward(self, x):
113 | x = self.norm(x).view(x.shape[0], -1)
114 | return self.model(x)
115 |
116 |
117 | class NormalizeLayer(nn.Module):
118 | """
119 | Normalizes across the first non-batch axis.
120 |
121 | Examples:
122 | (64, 3, 32, 32) [CIFAR] => normalizes across channels
123 | (64, 8) [UCI] => normalizes across features
124 | """
125 | def __init__(self, dim, device, mu=None, sigma=None):
126 | super().__init__()
127 | self.dim = dim
128 | if mu and sigma:
129 | self.mu = nn.Parameter(torch.tensor(mu, device=device).reshape(dim), requires_grad=False)
130 | self.log_sig = nn.Parameter(torch.log(torch.tensor(sigma, device=device)).reshape(dim), requires_grad=False)
131 | self.initialized = True
132 | else:
133 | raise ValueError
134 |
135 | def forward(self, x):
136 | if not self.initialized:
137 | self.initialize_parameters(x)
138 | self.initialized = True
139 | return (x - self.mu) / torch.exp(self.log_sig)
140 |
141 | def initialize_parameters(self, x):
142 | with torch.no_grad():
143 | mu = x.view(x.shape[0], x.shape[1], -1).mean((0, 2))
144 | std = x.view(x.shape[0], x.shape[1], -1).std((0, 2))
145 | self.mu.copy_(mu.data.view(self.dim))
146 | self.log_sig.copy_(torch.log(std).data.view(self.dim))
147 |
148 |
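`Forecaster.loss` returns `-Categorical(logits=theta).log_prob(y)`, which is the usual cross-entropy written as `logsumexp(theta) - theta[y]`. The sketch below evaluates that identity for a single example without PyTorch; `nll_from_logits` is a hypothetical name used only for illustration.

```python
import math


def nll_from_logits(logits, label):
    """Negative log-likelihood of `label` under softmax(logits)."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_norm = m + math.log(sum(math.exp(t - m) for t in logits))
    return log_norm - logits[label]


print(nll_from_logits([0.0, 0.0], 0))      # log(2): two equally likely classes
print(nll_from_logits([1000.0, 0.0], 0))   # near 0: class 0 is almost certain
```

Because the loss is computed per example, callers are presumably expected to average or sum it over the batch themselves.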
--------------------------------------------------------------------------------
/src/noises/__init__.py:
--------------------------------------------------------------------------------
1 | from .noises import *
2 |
3 |
4 | def parse_noise_from_args(args, device, dim):
5 | """
6 | Given a Namespace of arguments, returns the constructed object.
7 | """
8 | kwargs = {
9 | "sigma": args.sigma,
10 | "lambd": args.lambd,
11 | "k": args.k,
12 | "j": args.j,
13 | "a": args.a
14 | }
15 | kwargs = {k: v for k, v in kwargs.items() if v is not None}
16 | return eval(args.noise)(device=device, dim=dim, **kwargs)
17 |
18 |
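`parse_noise_from_args` resolves the class with `eval(args.noise)`, which evaluates whatever expression is passed on the command line. One possible alternative, sketched below, is an explicit name-to-class registry. The stub classes and the names `NOISE_REGISTRY` and `build_noise` are hypothetical; the real noise classes live in `src/noises/noises.py` and are not reproduced here.

```python
from argparse import Namespace


class _StubNoise:
    """Hypothetical stand-in for a noise class; only records its arguments."""

    def __init__(self, device, dim, **params):
        self.device, self.dim, self.params = device, dim, params


class Gaussian(_StubNoise):
    pass


class Laplace(_StubNoise):
    pass


# Only names listed here can be constructed from the command line.
NOISE_REGISTRY = {cls.__name__: cls for cls in (Gaussian, Laplace)}


def build_noise(args, device, dim):
    try:
        noise_cls = NOISE_REGISTRY[args.noise]
    except KeyError:
        raise ValueError("unknown noise: {!r}".format(args.noise)) from None
    names = ("sigma", "lambd", "k", "j", "a")
    kwargs = {n: getattr(args, n, None) for n in names}
    kwargs = {n: v for n, v in kwargs.items() if v is not None}  # drop unset options
    return noise_cls(device=device, dim=dim, **kwargs)


args = Namespace(noise="Gaussian", sigma=0.5, lambd=None, k=None, j=None, a=None)
noise = build_noise(args, device="cpu", dim=4)
print(type(noise).__name__, noise.params)
```

The trade-off is that every new noise class must be added to the registry, in exchange for rejecting any string that is not a known class name.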
--------------------------------------------------------------------------------
/src/noises/test_noises.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 | import tqdm
4 | import torch
5 | import src.noises.noises as noises
6 |
7 | class TestSigma(unittest.TestCase):
8 |
9 | def test_sigma(self):
10 | rel_tol = 1e-2
11 | nsamples = int(1e4)
12 | dim = 3 * 32 * 32
13 | dev = 'cpu'
14 | configs = [
15 | dict(noise=noises.Uniform),
16 | dict(noise=noises.Gaussian),
17 | dict(noise=noises.Laplace),
18 | dict(noise=noises.UniformBall),
19 | ]
20 | for a in [3, 10, 100, 1000]:
21 | configs.append(
22 | dict(noise=noises.Pareto, a=a)
23 | )
24 | for k in [1, 2, 10, 50]:
25 | for j in [0, 1, 10, 100, 1000]:
26 | configs.append(
27 | dict(noise=noises.ExpInf, k=k, j=j)
28 | )
29 | for k in [1, 2, 10, 20]:
30 | configs.append(
31 | dict(noise=noises.Exp1, k=k)
32 | )
33 | for a in [10, 100, 1000]:
34 | configs.append(
35 | dict(noise=noises.PowInf, a=a+dim)
36 | )
37 | for k in [1, 2, 10, 100]:
38 | for j in [0, 1, 10, 100, 1000]:
39 | configs.append(
40 | dict(noise=noises.Exp2, k=k, j=j)
41 | )
42 | for k in [1, 2, 10, 100]:
43 | for a in [10, 100, 1000]:
44 | a = (dim + a) / k
45 | configs.append(
46 | dict(noise=noises.Pow2, k=k, a=a)
47 | )
48 | for p in [0.2, 0.5, 1, 2, 4, 8]:
49 | configs.append(dict(noise=noises.Expp, p=p))
50 | for c in tqdm.tqdm(configs):
51 | c['device'] = dev
52 | c['dim'] = dim
53 | c['sigma'] = 1
54 | with self.subTest(config=dict(c)):
55 | noisecls = c.pop('noise')
56 | noise = noisecls(**c)
57 | samples = noise.sample(torch.zeros(nsamples, dim))
58 | self.assertEqual(samples.shape, torch.Size((nsamples, dim)))
59 | emp_sigma = samples.std()
60 | self.assertAlmostEqual(emp_sigma, noise.sigma,
61 | delta=rel_tol * emp_sigma)
62 |
63 | class TestRadii(unittest.TestCase):
64 |
65 | def test_laplace_linf_radii(self):
66 | '''Test that the "approx" and "integrate" modes of linf certification
67 | for Laplace agree with each other.'''
68 | noise = noises.Laplace(3*32*32, sigma=1)
69 | cert1 = noise.certify_linf(torch.arange(0.5, 1, 0.01))
70 | cert2 = noise.certify_linf(torch.arange(0.5, 1, 0.01), 'integrate')
71 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
72 |
73 | def test_exp2_l2_radii(self):
74 | r'''Test that for exp(-\|x\|_2), the differential and level set methods
75 | obtain similar robust radii.'''
76 | rs = torch.arange(0.5, 1, 0.01)
77 | with self.subTest(name='Exp2 test, k=1, j=0'):
78 | noise = noises.Exp2(3*32*32, sigma=1)
79 | cert1 = noise.certify_l2(rs)
80 | cert2 = noise.certify_l2_table(rs)
81 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-2))
82 | with self.subTest(name='Exp2 test, k=2, j=0'):
83 | noise = noises.Exp2(3*32*32, sigma=1, k=2)
84 | cert1 = noise.certify_l2(rs)
85 | cert2 = noise.certify_l2_table(rs)
86 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
87 |
88 | def test_exp1_l1_radii(self):
89 | r'''Test that for exp(-\|x\|_1), the laplace and differential method
90 | table certification match.'''
91 | rs = torch.arange(0.5, 1, 0.01)
92 | noise1 = noises.Laplace(3*32*32, sigma=1)
93 | noise2 = noises.Exp1(3*32*32, sigma=1, k=1)
94 | cert1 = noise1.certify_l1(rs)
95 | cert2 = noise2.certify_l1_table(rs)
96 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
97 |
98 | def test_expinf_linf_radii(self):
99 | r'''Test that ExpInf linf radii for k=1 matches known symbolic radii.'''
100 | rs = torch.arange(0.5, 1, 0.01)
101 | noise = noises.ExpInf(3*32*32, sigma=1, k=1)
102 | cert1 = noise.certify_linf(rs)
103 | cert2 = noise.certify_linf_table(rs)
104 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
105 |
106 | def test_expp_largep_radii(self):
107 | r'''Test that Expp with p=1 and p=2 recovers Laplace and Gaussian.'''
108 | rs = torch.arange(0.5, 1, 0.01)
109 | with self.subTest(name='large p: Expp(p=1) vs Laplace'):
110 | noisep = noises.Expp(3*32*32, sigma=1, p=1)
111 | noise1 = noises.Laplace(3*32*32, sigma=1)
112 | cert1 = noisep.certify_l1(rs)
113 | cert2 = noise1.certify_l1(rs)
114 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
115 |
116 | with self.subTest(name='large p: Expp(p=2) vs Gaussian'):
117 | noisep = noises.Expp(3*32*32, sigma=1, p=2)
118 | noise2 = noises.Gaussian(3*32*32, sigma=1)
119 | cert1 = noisep.certify_l1(rs)
120 | cert2 = noise2.certify_l1(rs)
121 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
122 |
123 | def test_expp_smallp_radii(self):
124 | r'''Test the log convex* radii for Expp, p = 1, recovers Laplace,
125 | and for p = 0.5, recovers analytic expression.'''
126 | rs = torch.arange(0.5, 1, 0.01)
127 | with self.subTest(name='small p: Expp(p=1) vs Laplace'):
128 | noisep = noises.Expp(3*32*32, sigma=1, p=1)
129 | noise1 = noises.Laplace(3*32*32, sigma=1)
130 | cert1 = noisep.certify_l1_smallp_table(rs)
131 | cert2 = noise1.certify_l1(rs)
132 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
133 |
134 | with self.subTest(name='small p: Expp(p=1/2), table vs closed form'):
135 | noise = noises.Expp(3*32*32, sigma=1, p=0.5)
136 | # analytic expression
137 | cert1 = noise.certify_l1(rs)
138 | # table
139 | cert2 = noise.certify_l1_smallp_table(rs)
140 | self.assertTrue(np.allclose(cert1, cert2, rtol=1e-3))
141 |
142 | def test_table_load(self):
143 | dim = 3 * 32 * 32
144 | dev = 'cpu'
145 | configs = []
146 | for k in [2, 3, 4]:
147 | configs.append(
148 | dict(noise=noises.Exp1, k=k, adv=1)
149 | )
150 | for k in [2]:
151 | for j in [2048, 3064, 3068, 3071]:
152 | configs.append(
153 | dict(noise=noises.Exp2, k=k, j=j, adv=2)
154 | )
155 | for k in [2]:
156 | for a in [1538, 1540, 1544]:
157 | configs.append(
158 | dict(noise=noises.Pow2, k=k, a=a, adv=2)
159 | )
160 | for k in [1, 2, 4, 8]:
161 | configs.append(
162 | dict(noise=noises.ExpInf, k=k, j=0, adv=np.inf)
163 | )
164 | for a in [4, 16, 32, 128]:
165 | configs.append(
166 | dict(noise=noises.PowInf, a=dim+a, adv=np.inf)
167 | )
168 | for p in [0.2, 0.5]:
169 | configs.append(dict(noise=noises.Expp, p=p, adv=1))
170 | rhos = torch.arange(0.5, 1, 0.01)
171 | for c in tqdm.tqdm(configs):
172 | c['device'] = dev
173 | c['dim'] = dim
174 | c['sigma'] = 1
175 | with self.subTest(config=dict(c)):
176 | noisecls = c.pop('noise')
177 | adv = c.pop('adv')
178 | noise = noisecls(**c)
179 | noise.certify(rhos, adv=adv)
180 |
181 | if __name__ == '__main__':
182 | unittest.main()
--------------------------------------------------------------------------------
/src/noises/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | from tqdm import tqdm
3 |
4 | import numpy as np
5 | import scipy as sp
6 | import scipy.special
7 | import scipy.stats
8 | import torch
9 | from scipy.stats import beta, binom, gamma, norm, laplace
10 | from torch.distributions import (Beta, Dirichlet, Gamma, Laplace, Normal,
11 | Pareto, Uniform)
12 |
13 |
14 | def atanh(x):
15 | return 0.5 * np.log((1 + x) / (1 - x))
16 |
17 | def relu(x):
18 | if isinstance(x, np.ndarray):
19 | return np.maximum(x, 0, x)
20 | else:
21 | return max(x, 0)
22 |
23 | def wfun(r, s, e, d):
24 | '''W function in the paper.
25 | Calculates the probability a point sampled from the surface of a ball
26 | of radius `r` centered at the origin is outside a ball of radius `s`
27 | with center `e` away from the origin.
28 | '''
29 | t = ((r+e)**2 - s**2)/(4*e*r)
30 | return beta((d-1)/2, (d-1)/2).cdf(t)
31 |
32 | def plexp(z, mode='lowerbound'):
33 | '''Computes LambertW(e^z) numerically safely.
34 | For small value of z, we use `scipy.special.lambertw`.
35 | For large value of z, we apply the approximation
36 |
37 | z - log(z) < W(e^z) < z - log(z) - log(1 - log(z)/z).
38 | '''
39 | if np.isscalar(z):
40 | if z > 500:
41 | if mode == 'lowerbound':
42 | return z - np.log(z)
43 | elif mode == 'upperbound':
44 | return z - np.log(z) - np.log(1 - np.log(z) / z)
45 | else:
46 | raise ValueError(f'Unknown mode: {mode}')
47 | else:
48 | return sp.special.lambertw(np.exp(z))
49 | else:
50 | if mode == 'lowerbound':
51 | # print(z)
52 | u = z - np.log(z)
53 | elif mode == 'upperbound':
54 | u = z - np.log(z) - np.log(1 - np.log(z) / z)
55 | else:
56 | raise ValueError(f'Unknown mode: {mode}')
57 | w = sp.special.lambertw(np.exp(z))
58 | w[z > 500] = u[z > 500]
59 | return w
60 |
61 |
62 | def sample_linf_sphere(device, shape):
63 | noise = (2 * torch.rand(shape, device=device) - 1
64 | ).reshape((shape[0], -1))
65 | sel_dims = torch.randint(noise.shape[1], size=(noise.shape[0],))
66 | idxs = torch.arange(0, noise.shape[0], dtype=torch.long)
67 | noise[idxs, sel_dims] = torch.sign(
68 | torch.rand(shape[0], device=device) - 0.5)
69 | return noise
70 |
71 | def sample_l2_sphere(device, shape):
72 | '''Sample uniformly from the unit l2 sphere.
73 | Inputs:
74 | device: 'cpu' | 'cuda' | other torch devices
75 | shape: a pair (batchsize, dim)
76 | Outputs:
77 | matrix of shape `shape` such that each row is a sample.
78 | '''
79 | noises = torch.randn(shape, device=device)
80 | noises /= noises.norm(dim=1, keepdim=True)
81 | return noises
82 |
83 | def sample_l1_sphere(device, shape):
84 | '''Sample uniformly from the unit l1 sphere, i.e. the cross polytope.
85 | Inputs:
86 | device: 'cpu' | 'cuda' | other torch devices
87 | shape: a pair (batchsize, dim)
88 | Outputs:
89 | matrix of shape `shape` such that each row is a sample.
90 | '''
91 | batchsize, dim = shape
92 | dirdist = Dirichlet(concentration=torch.ones(dim, device=device))
93 | noises = dirdist.sample([batchsize])
94 | signs = torch.sign(torch.rand_like(noises) - 0.5)
95 | return noises * signs
96 |
97 |
98 | def get_radii_from_table(table_rho, table_radii, prob_lb):
99 | prob_lb = prob_lb.numpy()
100 | idxs = np.searchsorted(table_rho, prob_lb, 'right') - 1
101 | return torch.tensor(table_radii[idxs], dtype=torch.float)
102 |
103 |
104 | def get_radii_from_convex_table(table_rho, table_radii, prob_lb):
105 | '''
106 | Assuming 1) radii is a convex function of rho and
107 | 2) table_rho[0] = 1/2, table_radii[0] = 0.
108 | Uses the basic fact that if f is convex and a < b, then
109 |
110 | f'(b) >= (f(b) - f(a)) / (b - a).
111 | '''
112 | prob_lb = prob_lb.numpy()
113 | idxs = np.searchsorted(table_rho, prob_lb, 'right') - 1
114 | slope = (table_radii[idxs] - table_radii[idxs-1]) / (
115 | table_rho[idxs] - table_rho[idxs-1]
116 | )
117 | rad = table_radii[idxs] + slope * (prob_lb - table_rho[idxs])
118 | rad[idxs == 0] = 0
119 | return torch.tensor(rad, dtype=torch.float)
120 |
121 |
122 | def diffmethod_table(Phi, inc, grid_type, upper, f):
123 | r'''Calculates a table of robust radii using the differential method.
124 | Given function Phi and the probability rho of correctly classifying an input
125 | perturbed by smoothing noise, the differential method gives
126 |
127 | \int_{1 - \rho}^{1/2} dp/\Phi(p)
128 |
129 | for the robust radius.
130 | Inputs:
131 | Phi: Phi function
132 | inc: grid increment
133 | grid_type: 'radius' | 'prob'
134 | In a `radius` grid, the probabilities rho are calculated as
135 |
136 | f([0, inc, 2 * inc, ..., upper - inc, upper]),
137 |
138 | where `f` and `upper` are additional inputs to this function.
139 | In a `prob` grid, the probabilities rho are spaced out evenly
140 | in increments of `inc`
141 |
142 | [1/2, 1/2 + inc, 1/2 + 2 * inc, ..., 1 - inc]
143 |
144 | upper: if `grid_type == 'radius'`, then the upper limit to the
145 | radius grid.
146 | f: the function used to determine the grid if `grid_type == 'radius'`
147 | Outputs:
148 | A Python dictionary `table` with
149 |
150 | table[rho] = radius
151 |
152 | for a grid of rho.
153 | '''
154 | table = {1/2: 0}
155 | lastrho = 1/2
156 | if grid_type == 'radius':
157 | rgrid = np.arange(inc, upper+inc, inc)
158 | grid = f(rgrid)
159 | elif grid_type == 'prob':
160 | grid = np.arange(1/2+inc, 1, inc)
161 | else:
162 | raise ValueError(f'Unknown grid_type {grid_type}')
163 | for rho in tqdm(grid):
164 | delta = sp.integrate.quad(lambda p: 1/Phi(p), 1 - rho, 1 - lastrho)[0]
165 | table[rho] = table[lastrho] + delta
166 | lastrho = rho
167 | return np.array(list(table.keys())), np.array(list(table.values()))
168 |
169 |
170 | def lvsetmethod_table(get_pbig, get_psmall, sigma, inc=0.01, upper=3):
171 | r'''Calculates a table of robust radii using the level set method.
172 | Inputs:
173 | get_pbig: function for computing the big measure of a Neyman-Pearson set.
174 | get_psmall: same, for the small measure of a Neyman-Pearson set.
175 | sigma: sqrt(E[\|noise\|^2_2/d])
176 | inc: radius increment of the table
177 | upper: upper limit of radius for the table.
178 | Outputs:
179 | table_rho, table_radii
180 | '''
181 | def find_NP_log_ratio(u, x0=0, bracket=(-100, 100)):
182 | return sp.optimize.root_scalar(
183 | lambda t: get_pbig(t, u) - 0.5, x0=x0, bracket=bracket)
184 | table = {0: {'radius': 0, 'rho': 1/2}}
185 | prv_root = 0
186 | for eps in tqdm(np.arange(inc, upper + inc, inc)):
187 | e = eps * sigma
188 | t = find_NP_log_ratio(e, prv_root)
189 | table[eps] = {
190 | 't': t.root,
191 | 'radius': e,
192 | 'normalized_radius': eps,
193 | 'converged': t.converged,
194 | 'info': t
195 | }
196 | if t.converged:
197 | table[eps]['rho'] = 1 - get_psmall(t.root, e)
198 | prv_root = t.root
199 | return np.array([x['rho'] for x in table.values()]), \
200 | np.array([x['radius'] for x in table.values()])
201 |
202 | def make_or_load(basename, make, inc=0.001, grid_type='radius', upper=3,
203 | save=True, loc='tables'):
204 | '''Calculate or load a table of robust radii.
205 | First try to load a table under `loc` with the corresponding
206 | parameters. If the files are not found, calculate the table.
207 | Inputs:
208 | basename: filename prefix identifying the table.
209 | make: callable invoked as `make(inc=inc, grid_type=grid_type, upper=upper)`
210 | that computes and returns `(table_rho, table_radii)`.
211 | inc: grid increment (default: 0.001)
212 | grid_type: 'radius' | 'prob' (default: 'radius')
213 | In a `radius` grid, the probabilities rho are calculated from the radii
214 |
215 | [0, inc, 2 * inc, ..., upper - inc, upper].
216 |
217 | In a `prob` grid, the probabilities rho are spaced evenly in increments of `inc`:
218 |
219 | [1/2, 1/2 + inc, 1/2 + 2 * inc, ..., 1 - inc]
220 |
221 | upper: upper limit of the radius grid if `grid_type == 'radius'` (default: 3).
222 | save: if True, save a newly computed table under `loc` (default: True).
223 | loc: directory holding the saved tables (default: 'tables').
224 | Outputs:
225 | table_rho, table_radii
226 | '''
227 | from os.path import join
228 | if grid_type == 'radius':
229 | basename += f'_inc{inc}_grid{grid_type}_upper{upper}'
230 | else:
231 | basename += f'_inc{inc}_grid{grid_type}'
232 | rho_fname = join(loc, basename + '_rho.npy')
233 | radii_fname = join(loc, basename + '_radii.npy')
234 | try:
235 | table_rho = np.load(rho_fname)
236 | table_radii = np.load(radii_fname)
237 | print('Found and loaded saved table: ' + basename)
238 | except FileNotFoundError:
239 | print('Making robust radii table: ' + basename)
240 | table_rho, table_radii = make(inc=inc, grid_type=grid_type, upper=upper)
241 | if save:
242 | import os
243 | print('Saving robust radii table')
244 | os.makedirs(loc, exist_ok=True)
245 | np.save(rho_fname, table_rho)
246 | np.save(radii_fname, table_radii)
247 | return table_rho, table_radii
248 |
249 |
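The docstring of `plexp` states the bounds `z - log(z) < W(e^z) < z - log(z) - log(1 - log(z)/z)`, which the function uses for `z > 500`, presumably because `exp(z)` overflows double precision near `z = 709.8`. The sketch below checks those bounds without SciPy by solving `w + log(w) = z` with Newton's method; `lambertw_exp` and `plexp_bounds` are hypothetical names, and the check assumes `z > 1`.

```python
import math


def lambertw_exp(z, tol=1e-12, max_iter=100):
    """Solve w + log(w) = z for w, i.e. w = W(e^z), assuming z > 1."""
    w = z - math.log(z)  # the lower bound, so Newton starts left of the root
    for _ in range(max_iter):
        step = (w + math.log(w) - z) / (1.0 + 1.0 / w)
        w -= step
        if abs(step) <= tol * max(1.0, abs(w)):
            break
    return w


def plexp_bounds(z):
    """Lower and upper bounds on W(e^z) from the plexp docstring."""
    lower = z - math.log(z)
    upper = lower - math.log(1.0 - math.log(z) / z)
    return lower, upper


for z in (2.0, 10.0, 100.0, 600.0, 1e4):
    lower, upper = plexp_bounds(z)
    w = lambertw_exp(z)
    print(z, lower < w < upper, upper - lower)
```

The gap between the two bounds shrinks as `z` grows, which is consistent with using either bound as a stand-in for the exact value at large `z`.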
--------------------------------------------------------------------------------
/src/smooth.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.distributions import Categorical, Normal
5 | from statsmodels.stats.proportion import proportion_confint
6 |
7 |
8 | def direct_train_log_lik(model, x, y, noise, sample_size=16):
9 | """
10 | Log-likelihood for direct training (numerically stable with logsumexp trick).
11 | """
12 | samples_shape = torch.Size([x.shape[0], sample_size]) + x.shape[1:]
13 | samples = x.unsqueeze(1).expand(samples_shape)
14 | samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
15 | samples = noise.sample(samples)
16 | thetas = model.forward(samples).view(x.shape[0], sample_size, -1)
17 | return torch.logsumexp(thetas[torch.arange(x.shape[0]), :, y] - \
18 | torch.logsumexp(thetas, dim=2), dim=1) - \
19 | torch.log(torch.tensor(sample_size, dtype=torch.float, device=x.device))
20 |
21 | def smooth_predict_soft(model, x, noise, sample_size=64, noise_batch_size=512):
22 | """
23 | Make soft predictions for a model smoothed by noise.
24 |
25 | Returns
26 | -------
27 | predictions: Categorical, probabilities for each class returned by soft smoothed classifier
28 | """
29 | counts = None
30 | num_samples_left = sample_size
31 |
32 | while num_samples_left > 0:
33 |
34 | shape = torch.Size([x.shape[0], min(num_samples_left, noise_batch_size)]) + x.shape[1:]
35 | samples = x.unsqueeze(1).expand(shape)
36 | samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
37 | samples = noise.sample(samples.view(len(samples), -1)).view(samples.shape)
38 | logits = model.forward(samples).view(shape[:2] + torch.Size([-1]))
39 | if counts is None:
40 | counts = torch.zeros(x.shape[0], logits.shape[-1], dtype=torch.float, device=x.device)
41 | counts += F.softmax(logits, dim=-1).sum(dim=1)  # sum, not mean: Categorical normalizes, and batches may differ in size
42 | num_samples_left -= noise_batch_size
43 |
44 | return Categorical(probs=counts)
45 |
46 | def smooth_predict_hard(model, x, noise, sample_size=64, noise_batch_size=512):
47 | """
48 | Make hard predictions for a model smoothed by noise.
49 |
50 | Returns
51 | -------
52 | predictions: Categorical, probabilities for each class returned by hard smoothed classifier
53 | """
54 | counts = None
55 | num_samples_left = sample_size
56 |
57 | while num_samples_left > 0:
58 |
59 | shape = torch.Size([x.shape[0], min(num_samples_left, noise_batch_size)]) + x.shape[1:]
60 | samples = x.unsqueeze(1).expand(shape)
61 | samples = samples.reshape(torch.Size([-1]) + samples.shape[2:])
62 | samples = noise.sample(samples.view(len(samples), -1)).view(samples.shape)
63 | logits = model.forward(samples).view(shape[:2] + torch.Size([-1]))
64 | top_cats = torch.argmax(logits, dim=2)
65 | if counts is None:
66 | counts = torch.zeros(x.shape[0], logits.shape[-1], dtype=torch.float, device=x.device)
67 | counts += F.one_hot(top_cats, logits.shape[-1]).float().sum(dim=1)
68 | num_samples_left -= noise_batch_size
69 |
70 | return Categorical(probs=counts)
71 |
72 | def certify_prob_lb(model, x, top_cats, alpha, noise, sample_size=10**5, noise_batch_size=512):
73 | """
74 | Certify a probability lower bound (rho).
75 |
76 | Returns
77 | -------
78 | prob_lb: n-length tensor of floats
79 | """
80 | preds = smooth_predict_hard(model, x, noise, sample_size, noise_batch_size)
81 | top_probs = preds.probs.gather(dim=1, index=top_cats.unsqueeze(1)).detach().cpu()
82 | lower, _ = proportion_confint(top_probs * sample_size, sample_size, alpha=alpha, method="beta")
83 | lower = torch.tensor(lower.squeeze(), dtype=torch.float)
84 | return lower
85 |
86 | #def certify_smoothed(model, x, top_cats, alpha, noise, adv, sample_size=10**5, noise_batch_size=512):
87 | # """
88 | # Certify a smoothed model, given the top categories to certify for.
89 | #
90 | # Returns
91 | # -------
92 | # lower: n-length tensor of floats, the probability lower bounds
93 | # radius: n-length tensor
94 | # """
95 | # preds = smooth_predict_hard(model, x, noise, sample_size, noise_batch_size)
96 | # top_probs = preds.probs.gather(dim=1, index=top_cats.unsqueeze(1)).detach().cpu()
97 | # lower, _ = proportion_confint(top_probs * sample_size, sample_size, alpha=alpha, method="beta")
98 | # lower = torch.tensor(lower.squeeze(), dtype=torch.float)
99 | # return lower, noise.certify(lower, adv=adv)
100 | #
101 |
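
`certify_prob_lb` passes the vote counts to `proportion_confint(..., method="beta")`, which computes a Clopper-Pearson interval. As a rough sketch of what the lower end of that interval means, the standard-library code below finds the smallest `p` for which observing at least `k` successes out of `n` still has probability `alpha / 2`. The function names are hypothetical, and the bisection is only practical for small `n`; it is not how statsmodels computes the bound.

```python
from math import comb


def binom_upper_tail(k, n, p):
    """P(X >= k) for X ~ Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p)**(n - i) for i in range(k, n + 1))


def clopper_pearson_lower(k, n, alpha, tol=1e-12):
    """Lower end of the two-sided (1 - alpha) Clopper-Pearson interval.

    Hypothetical helper for illustration only.
    """
    if k == 0:
        return 0.0
    lo, hi = 0.0, 1.0
    # The upper tail is increasing in p, so bisect for tail == alpha / 2.
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if binom_upper_tail(k, n, mid) < alpha / 2:
            lo = mid
        else:
            hi = mid
    return lo


# With all 100 noisy samples voting for the top class and alpha = 0.001,
# the bound has the closed form (alpha / 2) ** (1 / n).
print(clopper_pearson_lower(100, 100, alpha=0.001))
```

Because the interval is two-sided, the lower bound alone holds with confidence `1 - alpha / 2`.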
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pathlib
3 | import os
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | from argparse import ArgumentParser
8 | from torchnet import meter
9 | from torch.utils.data import DataLoader, Subset
10 | from tqdm import tqdm
11 | from src.models import *
12 | from src.attacks import *
13 | from src.smooth import *
14 | from src.noises import *
15 | from src.datasets import *
16 |
17 |
18 | if __name__ == "__main__":
19 |
20 | argparser = ArgumentParser()
21 | argparser.add_argument("--device", default="cuda", type=str)
22 | argparser.add_argument("--batch-size", default=2, type=int)
23 | argparser.add_argument("--num-workers", default=min(os.cpu_count(), 8), type=int)
24 | argparser.add_argument("--sample-size-pred", default=64, type=int)
25 | argparser.add_argument("--sample-size-cert", default=100000, type=int)
26 | argparser.add_argument("--noise-batch-size", default=512, type=int)
27 | argparser.add_argument("--sigma", default=0.0, type=float)
28 | argparser.add_argument("--noise", default="Clean", type=str)
29 | argparser.add_argument("--k", default=None, type=int)
30 | argparser.add_argument("--j", default=None, type=int)
31 | argparser.add_argument("--a", default=None, type=int)
32 | argparser.add_argument("--lambd", default=None, type=float)
33 | argparser.add_argument("--dataset-skip", default=1, type=int)
34 | argparser.add_argument("--experiment-name", default="cifar", type=str)
35 | argparser.add_argument("--dataset", default="cifar", type=str)
36 | argparser.add_argument("--model", default="WideResNet", type=str)
37 | argparser.add_argument("--rotate", action="store_true")
38 | argparser.add_argument("--output-dir", type=str, default=os.getenv("PT_OUTPUT_DIR"))
39 | argparser.add_argument("--save-path", type=str, default=None)
40 | args = argparser.parse_args()
41 |
42 | test_dataset = get_dataset(args.dataset, "test")
43 | test_dataset = Subset(test_dataset, list(range(0, len(test_dataset), args.dataset_skip)))
44 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size,
45 | num_workers=args.num_workers)
46 |
47 | if not args.save_path:
48 | save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
49 | else:
50 | save_path = args.save_path
51 |
52 | model = eval(args.model)(dataset=args.dataset, device=args.device)
53 | saved_dict = torch.load(save_path)
54 | model.load_state_dict(saved_dict)
55 | model.eval()
56 |
57 | noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))
58 |
59 | results = {
60 | "preds": np.zeros((len(test_dataset), get_num_labels(args.dataset))),
61 | "labels": np.zeros(len(test_dataset)),
62 | "prob_lb": np.zeros(len(test_dataset)),
63 | "preds_nll": np.zeros(len(test_dataset)),
64 | "radius_l1": np.zeros(len(test_dataset)),
65 | "radius_l2": np.zeros(len(test_dataset)),
66 | "radius_linf": np.zeros(len(test_dataset)),
67 | }
68 |
69 | for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)):
70 |
71 | x, y = x.to(args.device), y.to(args.device)
72 | x = rotate_noise.sample(x) if args.rotate else x
73 |
74 | preds = smooth_predict_hard(model, x, noise, args.sample_size_pred,
75 | noise_batch_size=args.noise_batch_size)
76 | top_cats = preds.probs.argmax(dim=1)
77 | prob_lb = certify_prob_lb(model, x, top_cats, 0.001, noise,
78 | args.sample_size_cert, noise_batch_size=args.noise_batch_size)
79 |
80 | lower, upper = i * args.batch_size, (i + 1) * args.batch_size
81 | results["preds"][lower:upper, :] = preds.probs.data.cpu().numpy()
82 | results["labels"][lower:upper] = y.data.cpu().numpy()
83 | results["prob_lb"][lower:upper] = prob_lb.cpu().numpy()
84 | results["radius_l1"][lower:upper] = noise.certify_l1(prob_lb).cpu().numpy()
85 | results["radius_l2"][lower:upper] = noise.certify_l2(prob_lb).cpu().numpy()
86 | results["radius_linf"][lower:upper] = noise.certify_linf(prob_lb).cpu().numpy()
87 | results["preds_nll"][lower:upper] = -preds.log_prob(y).cpu().numpy()
88 |
89 | save_path = f"{args.output_dir}/{args.experiment_name}"
90 | pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
91 | for k, v in results.items():
92 | np.save(f"{save_path}/{k}.npy", v)
93 |
94 | train_dataset = get_dataset(args.dataset, "train")
95 | train_loader = DataLoader(train_dataset, shuffle=False,
96 | batch_size=args.batch_size,
97 | num_workers=args.num_workers)
98 | acc_meter = meter.AverageValueMeter()
99 |
100 | for x, y in tqdm(train_loader):
101 |
102 | x, y = x.to(args.device), y.to(args.device)
103 | x = rotate_noise.sample(x) if args.rotate else x
104 |
105 | preds = smooth_predict_hard(model, x, noise, args.sample_size_pred,
106 | args.noise_batch_size)
107 | top_cats = preds.probs.argmax(dim=1)
108 | acc_meter.add(torch.sum(top_cats == y).cpu().data.numpy(), n=len(x))
109 |
110 | print("Training accuracy: ", acc_meter.value())
111 | save_path = f"{args.output_dir}/{args.experiment_name}"
112 | pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
113 | np.save(f"{save_path}/acc_train.npy", acc_meter.value())
114 |
115 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import pathlib
3 | import pickle
4 | import os
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | from argparse import ArgumentParser
10 | from torchnet import meter
11 | from torch.utils.data import DataLoader
12 | from tqdm import tqdm
13 | from src.models import *
14 | from src.noises import *
15 | from src.smooth import *
16 | from src.attacks import pgd_attack_smooth
17 | from src.datasets import get_dataset, get_dim
18 |
19 |
20 | if __name__ == "__main__":
21 |
22 | argparser = ArgumentParser()
23 | argparser.add_argument("--device", default="cuda", type=str)
24 | argparser.add_argument("--lr", default=0.1, type=float)
25 | argparser.add_argument("--batch-size", default=64, type=int)
26 | argparser.add_argument("--num-workers", default=min(os.cpu_count(), 8), type=int)
27 | argparser.add_argument("--num-epochs", default=120, type=int)
28 | argparser.add_argument("--print-every", default=20, type=int)
29 | argparser.add_argument("--save-every", default=50, type=int)
30 | argparser.add_argument("--experiment-name", default="cifar", type=str)
31 | argparser.add_argument("--noise", default="Clean", type=str)
32 | argparser.add_argument("--sigma", default=None, type=float)
33 | argparser.add_argument("--adv", default=2, type=int)
34 | argparser.add_argument("--eps", default=0.0, type=float)
35 | argparser.add_argument("--k", default=None, type=int)
36 | argparser.add_argument("--j", default=None, type=int)
37 | argparser.add_argument("--a", default=None, type=int)
38 | argparser.add_argument("--lambd", default=None, type=float)
39 | argparser.add_argument("--model", default="WideResNet", type=str)
40 | argparser.add_argument("--dataset", default="cifar", type=str)
41 | argparser.add_argument("--adversarial", action="store_true")
42 | argparser.add_argument("--stability", action="store_true")
43 | argparser.add_argument("--direct", action="store_true")
44 | argparser.add_argument("--save-path", type=str, default=None)
45 | argparser.add_argument('--output-dir', type=str, default=os.getenv("PT_OUTPUT_DIR"))
46 | args = argparser.parse_args()
47 |
48 | logging.basicConfig(level=logging.INFO)
49 | logger = logging.getLogger(__name__)
50 |
51 | model = eval(args.model)(dataset=args.dataset, device=args.device)
52 |
53 | # for fine-tuning a pre-trained model, we strip out the last fc layer
54 | if args.save_path:
55 | saved_dict = torch.load(args.save_path)
56 | del saved_dict["model.module.fc.weight"]
57 | del saved_dict["model.module.fc.bias"]
58 | model.load_state_dict(saved_dict, strict=False)
59 |
60 | model.train()
61 |
62 | train_loader = DataLoader(get_dataset(args.dataset, "train"),
63 | shuffle=True,
64 | batch_size=args.batch_size,
65 | num_workers=args.num_workers,
66 | pin_memory=False)
67 |
68 | optimizer = optim.SGD(model.parameters(),
69 | lr=args.lr,
70 | momentum=0.9,
71 | weight_decay=1e-4,
72 | nesterov=True)
73 | annealer = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs)
74 |
75 | loss_meter = meter.AverageValueMeter()
76 | time_meter = meter.TimeMeter(unit=False)
77 |
78 | noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))
79 |
80 | train_losses = []
81 |
82 | for epoch in range(args.num_epochs):
83 |
84 | for i, (x, y) in enumerate(train_loader):
85 |
86 | x, y = x.to(args.device), y.to(args.device)
87 |
88 | if args.adversarial:
89 | eps = min(args.eps, epoch * args.eps / (args.num_epochs // 2))
90 | x, loss = pgd_attack_smooth(model, x, y, eps, noise, sample_size=4, adv=args.adv)
91 | elif args.stability:
92 | x_tilde = noise.sample(x.view(len(x), -1)).view(x.shape)
93 | x = noise.sample(x.view(len(x), -1)).view(x.shape)
94 | elif not args.direct:
95 | x = noise.sample(x.view(len(x), -1)).view(x.shape)
96 |
97 | if args.direct:
98 | loss = -direct_train_log_lik(model, x, y, noise, sample_size=16).mean()
99 | elif args.stability:
100 | pred_x = model.forecast(model.forward(x))
101 | pred_x_tilde = model.forecast(model.forward(x_tilde))
102 | loss = -pred_x.log_prob(y) + 6.0 * torch.distributions.kl_divergence(pred_x, pred_x_tilde)
103 | loss = loss.mean()
104 | elif not args.adversarial:
105 | loss = model.loss(x, y).mean()
106 |
107 | optimizer.zero_grad()
108 | loss.backward()
109 | optimizer.step()
110 | loss_meter.add(loss.cpu().data.numpy(), n=1)
111 |
112 | if i % args.print_every == 0:
113 | logger.info(f"Epoch: {epoch}\t"
114 | f"Itr: {i} / {len(train_loader)}\t"
115 | f"Loss: {loss_meter.value()[0]:.2f}\t"
116 | f"Mins: {(time_meter.value() / 60):.2f}\t"
117 | f"Experiment: {args.experiment_name}")
118 | train_losses.append(loss_meter.value()[0])
119 | loss_meter.reset()
120 |
121 | if (epoch + 1) % args.save_every == 0:
122 | save_path = f"{args.output_dir}/{args.experiment_name}/{epoch}/"
123 | pathlib.Path(save_path).mkdir(parents=True, exist_ok=True)
124 | torch.save(model.state_dict(), f"{save_path}/model_ckpt.torch")
125 |
126 | annealer.step()
127 |
128 | pathlib.Path(f"{args.output_dir}/{args.experiment_name}").mkdir(parents=True, exist_ok=True)
129 | save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
130 | torch.save(model.state_dict(), save_path)
131 | args_path = f"{args.output_dir}/{args.experiment_name}/args.pkl"
132 | pickle.dump(args, open(args_path, "wb"))
133 | save_path = f"{args.output_dir}/{args.experiment_name}/losses_train.npy"
134 | np.save(save_path, np.array(train_losses))
135 |
136 |
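
With `--adversarial`, the training loop above ramps the attack budget linearly from zero to `--eps` over the first half of training and then holds it constant. The sketch below isolates that schedule; the function name `warmup_eps` is hypothetical, and the expression is copied from the loop.

```python
def warmup_eps(epoch, eps, num_epochs):
    """Attack budget at a given epoch (hypothetical helper name)."""
    # Linear ramp from 0 to eps over the first num_epochs // 2 epochs,
    # then clipped at eps for the remainder of training.
    return min(eps, epoch * eps / (num_epochs // 2))


schedule = [warmup_eps(e, eps=0.5, num_epochs=8) for e in range(8)]
print(schedule)  # [0.0, 0.125, 0.25, 0.375, 0.5, 0.5, 0.5, 0.5]
```

Note that the expression divides by `num_epochs // 2`, so it requires `num_epochs >= 2`.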
--------------------------------------------------------------------------------
/src/verify.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from argparse import ArgumentParser
4 | from torch.utils.data import DataLoader
5 | from tqdm import tqdm
6 | from src.attacks import *
7 | from src.noises import *
8 | from src.models import *
9 | from src.datasets import get_dataset, get_dim, get_num_labels
10 |
11 |
12 | if __name__ == "__main__":
13 |
14 | argparser = ArgumentParser()
15 | argparser.add_argument("--device", default="cuda", type=str)
16 | argparser.add_argument("--batch-size", default=4, type=int)
17 | argparser.add_argument("--num-workers", default=os.cpu_count(), type=int)
18 | argparser.add_argument("--sample-size-pred", default=64, type=int)
19 | argparser.add_argument("--noise-batch-size", default=512, type=int)
20 | argparser.add_argument("--sigma", default=0.0, type=float)
21 | argparser.add_argument("--noise", default="Clean", type=str)
22 | argparser.add_argument("--k", default=None, type=int)
23 | argparser.add_argument("--j", default=None, type=int)
24 | argparser.add_argument("--a", default=None, type=int)
25 | argparser.add_argument("--lambd", default=None, type=float)
26 | argparser.add_argument("--adv", default=2, type=int)
27 | argparser.add_argument("--experiment-name", default="cifar", type=str)
28 | argparser.add_argument("--dataset", default="cifar", type=str)
29 | argparser.add_argument("--model", default="WideResNet", type=str)
30 | argparser.add_argument("--output-dir", type=str, default=os.getenv("PT_OUTPUT_DIR"))
31 | args = argparser.parse_args()
32 |
33 | test_dataset = get_dataset(args.dataset, "test")
34 | test_loader = DataLoader(test_dataset, shuffle=False, batch_size=args.batch_size, # todo: fix
35 | num_workers=args.num_workers)
36 |
37 | save_path = f"{args.output_dir}/{args.experiment_name}/model_ckpt.torch"
38 | model = eval(args.model)(dataset=args.dataset, device=args.device)
39 | model.load_state_dict(torch.load(save_path))
40 | model.eval()
41 |
42 | noise = parse_noise_from_args(args, device=args.device, dim=get_dim(args.dataset))
43 |
44 | eps_range = (3.0, 2.0, 1.0, 0.5, 0.25)
45 |
46 | results = {f"preds_adv_{eps}": np.zeros((len(test_dataset), get_num_labels(args.dataset))) for eps in eps_range}
47 |
48 | for i, (x, y) in tqdm(enumerate(test_loader), total=len(test_loader)):
49 |
50 | x, y = x.to(args.device), y.to(args.device)
51 | lower, upper = i * args.batch_size, (i + 1) * args.batch_size
52 |
53 | for eps in eps_range:
54 | x_adv, _ = pgd_attack_smooth(model, x, y, eps=eps, noise=noise, sample_size=128,
55 | steps=20, p=args.adv, clamp=(0, 1))
56 | preds_adv = smooth_predict_hard(model, x_adv, noise, args.sample_size_pred,
57 | args.noise_batch_size)
58 | results[f"preds_adv_{eps}"][lower:upper, :] = preds_adv.probs.data.cpu().numpy()
59 | assert ((x - x_adv).reshape(x.shape[0], -1).norm(dim=1, p=args.adv) <= eps + 1e-2).all()
60 |
61 | save_path = f"{args.output_dir}/{args.experiment_name}"
62 | for k, v in results.items():
63 | np.save(f"{save_path}/{k}.npy", v)
64 |
65 |
--------------------------------------------------------------------------------
/svgs/1418cce7d60743be1c545cd950367159.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/1fa8048512f84790ef174f591d0cb851.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/336fefe2418749fabf50594e52f7b776.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/44c65658d6cd134b1599c29b31949f77.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/5dc1880e644c7b3a0e9fa954759762ea.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/8244067f9118b85361c6645cc9f1c526.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/839a0dc412c4f8670dd1064e0d6d412f.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/8d2d1eabb21bb41807292151fe468472.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/DistributionVenn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/DistributionVenn.png
--------------------------------------------------------------------------------
/svgs/b52b48d8661f69776e1b6650998d5067.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/bb9e6385ceb6d4a2d83a5b51a3c870c9.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/bd5b313d1d74ae2fc57ddb870603d84b.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/e1085464f81e12de4a74d54d14eb5dc5.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/e703845884313f30712bfc7262a5e65b.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/ec90b4fe342a37de851db6db2b08d4f4.svg:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/svgs/envelopes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/envelopes.png
--------------------------------------------------------------------------------
/svgs/robust-radii.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/robust-radii.png
--------------------------------------------------------------------------------
/svgs/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tonyduan/rs4a/a02831806ee4a117ef82ce5738fe8c5e8f180b06/svgs/table.png
--------------------------------------------------------------------------------