├── .gitignore ├── LICENSE ├── README.md ├── baselines.py ├── cnns.py ├── data.py ├── dp_utils.py ├── log.py ├── models.py ├── requirements.txt ├── scripts ├── run_baselines_cifar10.py ├── run_cnns_cifar10.py └── run_cnns_cifar10_scat.py ├── tiny_images.py ├── train_utils.py └── transfer ├── __init__.py ├── extract_cifar100.py ├── extract_simclr.py ├── resnext.py └── transfer_cifar.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | figures/ 133 | .DS_STORE 134 | .idea/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 ftramer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Handcrafted-DP 2 | 3 | **This repository contains code to train differentially private models 4 | with handcrafted vision features.** 5 | 6 | These models are introduced and analyzed in: 7 | 8 | *Differentially Private Learning Needs Better Features (Or Much More Data)*
9 | **Florian Tramèr and Dan Boneh**
10 | [arXiv:2011.11660](http://arxiv.org/abs/2011.11660) 11 | 12 | ## Installation 13 | 14 | The main dependencies are [pytorch](https://github.com/pytorch/pytorch), 15 | [kymatio](https://github.com/kymatio/kymatio) 16 | and [opacus](https://github.com/pytorch/opacus). 17 | 18 | You can install all requirements with: 19 | ```bash 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | The code was tested with `python 3.7`, `torch 1.6` and `CUDA 10.1`. 24 | 25 | 26 | ## Example Usage and Results 27 | 28 | This table presents the main results from our paper. For each dataset, we target a privacy budget of `(epsilon=3, delta=1e-5)`. 29 | We compare three types of models: 30 | 1) Regular CNNs trained "end-to-end" from image pixels. 31 | 2) Linear models fine-tuned on top of "handcrafted" 32 | [ScatterNet](https://arxiv.org/abs/1412.8659) features. 33 | 3) Small CNNs fine-tuned on ScatterNet features. 34 | 35 | | Dataset | End-to-end CNN | ScatterNet + linear | ScatterNet + CNN | 36 | | ------------- | ------------- | ------------- | ------------- | 37 | | MNIST | 98.1% | **98.7%** | **98.7%** 38 | | Fashion-MNIST | 86.0% | **89.7%** | 89.0% 39 | | CIFAR-10 | 59.2% | 67.0% | **69.3%** 40 | 41 | 42 | ### Determining the Noise Multiplier 43 | 44 | The [DP-SGD](https://arxiv.org/abs/1607.00133) algorithm adds noise 45 | to every gradient update to preserve privacy. 46 | The "noise multiplier" is a parameter that determines the amount of noise 47 | that is added. 48 | The higher the noise multiplier, the stronger the privacy guarantees, 49 | but the harder it is to train accurate models. 50 | 51 | In our paper, we compute the noise multiplier so that our fixed privacy budget 52 | of `(epsilon=3, delta=1e-5)` is consumed after some fixed number of epochs. 53 | The noise multiplier can be computed as: 54 | ```python 55 | from dp_utils import get_noise_mul 56 | num_samples = 50000 57 | batch_size = 512 58 | target_epsilon = 3 59 | target_delta = 1e-5 60 | epochs = 40 61 | noise_mul = get_noise_mul(num_samples, batch_size, target_epsilon, epochs, target_delta=target_delta) 62 | ``` 63 | 64 | ### End-to-end CNNs 65 | 66 | To reproduce the results for end-to-end CNNs with the best hyper-parameters from our paper, run 67 | ```bash 68 | python3 cnns.py --dataset=mnist --batch_size=512 --lr=0.5 --noise_multiplier=1.23 69 | python3 cnns.py --dataset=fmnist --batch_size=2048 --lr=4 --noise_multiplier=2.15 70 | python3 cnns.py --dataset=cifar10 --batch_size=1024 --lr=1 --noise_multiplier=1.54 71 | ``` 72 | The noise multipliers are computed so as to consume the privacy budget in 73 | respectively `40`, `40` and `30` epochs. 
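As a sanity check, these multipliers can be recomputed with the accounting utilities in `dp_utils.py`. The sketch below is illustrative only: it assumes the standard training-set sizes (60,000 for MNIST/Fashion-MNIST, 50,000 for CIFAR-10) and the `(epsilon=3, delta=1e-5)` target, and the values it prints should land close to (but not necessarily exactly on) the multipliers used in the commands above.

```python
from dp_utils import get_noise_mul, get_epsilon

# (num_samples, batch_size, epochs) for the three commands above
runs = {
    "mnist": (60_000, 512, 40),
    "fmnist": (60_000, 2048, 40),
    "cifar10": (50_000, 1024, 30),
}

for name, (n, bs, epochs) in runs.items():
    mul = get_noise_mul(n, bs, target_epsilon=3, epochs=epochs, target_delta=1e-5)
    # double-check the budget actually spent after that many epochs
    steps = (n // bs) * epochs
    eps = get_epsilon(bs / n, mul, steps, target_delta=1e-5)
    print(f"{name}: noise multiplier ~ {mul:.2f}, epsilon ~ {eps:.2f}")
```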
74 | 75 | ### ScatterNet models 76 | 77 | To reproduce the results for linear ScatterNet models, run 78 | ```bash 79 | python3 baselines.py --dataset=mnist --batch_size=4096 --lr=8 --input_norm=BN --bn_noise_multiplier=8 --noise_multiplier=3.04 80 | python3 baselines.py --dataset=fmnist --batch_size=8192 --lr=16 --input_norm=GroupNorm --num_groups=27 --noise_multiplier=4.05 81 | python3 baselines.py --dataset=cifar10 --batch_size=8192 --lr=4 --input_norm=BN --bn_noise_multiplier=8 --noise_multiplier=5.67 82 | ``` 83 | And for CNNs fine-tuned on ScatterNet features, run: 84 | ```bash 85 | python3 cnns.py --dataset=mnist --use_scattering --batch_size=1024 --lr=1 --input_norm=BN --bn_noise_multiplier=8 --noise_multiplier=1.35 86 | python3 cnns.py --dataset=fmnist --use_scattering --batch_size=2048 --lr=4 --input_norm=GroupNorm --num_groups=27 --noise_multiplier=2.15 87 | python3 cnns.py --dataset=cifar10 --use_scattering --batch_size=8192 --lr=4 --input_norm=BN --bn_noise_multiplier=8 --noise_multiplier=5.67 88 | ``` 89 | 90 | There are a few additional parameters here: 91 | * The `input_norm` parameter determines how the ScatterNet features are normalized. 92 | We support Group Normalization (`input_norm=GroupNorm`) 93 | and (frozen) Batch Normalization (`input_norm=BN`). 94 | * When using Group Normalization, the `num_groups` parameter specifies the number 95 | of groups into which to split the features for normalization. 96 | * When using Batch Normalization, we first privately compute the mean and variance 97 | of the features across the entire training set. This requires adding noise to 98 | these statistics. The `bn_noise_multiplier` specifies the scale of the noise. 99 | 100 | When using Batch Normalization, we *compose* the privacy losses of the 101 | normalization step and of the DP-SGD algorithm. 102 | Specifically, we first compute the Rényi-DP budget for the normalization step, 103 | and then compute the `noise_multiplier` of the DP-SGD algorithm so that the total 104 | privacy budget is used after a fixed number of epochs: 105 | ```python 106 | from dp_utils import get_renyi_divergence, get_noise_mul 107 | rdp = 2 * get_renyi_divergence(1, bn_noise_multiplier) 108 | noise_mul = get_noise_mul(num_samples, batch_size, target_epsilon, epochs, rdp_init=rdp, target_delta=target_delta) 109 | ``` 110 | 111 | ### Measuring the Data Complexity of Private Learning 112 | 113 | To understand how expensive it currently is to exceed handcrafted features 114 | with private end-to-end deep learning, we compare the performance of the 115 | above models on increasingly large training sets.
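The accounting works the same way as the dataset grows; see the sketch below for the larger CIFAR-10 setting described next (a `550'000`-sample training set, a `120`-epoch budget, and `delta=1/(2N)`). This is a minimal illustrative sketch for the end-to-end CNN case (no extra budget spent on BN statistics), and the result should come out roughly at the `noise_multiplier=1.1` used in the commands that follow.

```python
from dp_utils import get_noise_mul

n = 550_000           # CIFAR-10 plus 500K pseudo-labelled tiny images
delta = 1 / (2 * n)   # ~ 9.09e-7
mul = get_noise_mul(n, 8192, target_epsilon=3, epochs=120, target_delta=delta)
print(mul)            # expected to be roughly 1.1
```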
116 | 117 | To obtain a larger dataset comparable to CIFAR-10, we use `500'000` additional 118 | pseudo-labelled tiny images collected by [Carmon et al.](https://github.com/yaircarmon/semisup-adv) 119 | 120 | To re-train the above models for `120` epochs on the full dataset of `550'000` images, use: 121 | 122 | ```bash 123 | python3 tiny_images.py --batch_size=8192 --lr=16 --delta=9.09e-7 --model=linear --use_scattering --bn_noise_multiplier=8 --epochs=120 --noise_multiplier=1.1 124 | python3 tiny_images.py --batch_size=8192 --lr=16 --delta=9.09e-7 --model=cnn --epochs=120 --noise_multiplier=1.1 125 | python3 tiny_images.py --batch_size=8192 --lr=16 --delta=9.09e-7 --model=cnn --use_scattering --bn_noise_multiplier=8 --epochs=120 --noise_multiplier=1.1 126 | ``` 127 | 128 | For a privacy budget of `(epsilon=3, delta=1/(2N))`, where `N` is the size of the 129 | training data, we obtain the following improved test accuracies on CIFAR-10: 130 | 131 | | N | End-to-end CNN | ScatterNet + linear | ScatterNet + CNN | 132 | | ------------- | ------------- | ------------- | ------------- | 133 | | 50K | 59.2% | 67.0% | **69.3%** | 134 | | 550K | **75.8%** | 70.7% | 74.5% | 135 | 136 | ### Private Transfer Learning 137 | Our paper also contains some results for private transfer learning to CIFAR-10. 138 | For a privacy budget of `(epsilon=2, delta=1e-5)` we get: 139 | 140 | | Source Model | Transfer Accuracy on CIFAR-10 | 141 | | ------------- | ------------- | 142 | | ResNeXt-29 (CIFAR-100) | 79.6% | 143 | | SimCLR v2 (unlabelled ImageNet) | 92.4% | 144 | 145 | These results can be reproduced as follows. 146 | First, you'll need to download the `resnext-8x64d` model from 147 | [here](https://github.com/bearpaw/pytorch-classification). 148 | 149 | Then, we extract features from the source models: 150 | ```bash 151 | python3 -m transfer.extract_cifar100 152 | python3 -m transfer.extract_simclr 153 | ``` 154 | This will create a `transfer/features` directory unless one already exists.
155 | 156 | Finally, we train linear models with DP-SGD on top of these features: 157 | ```bash 158 | python3 -m transfer.transfer_cifar --feature_path=transfer/features/cifar100_resnext --batch_size=2048 --lr=8 --noise_multiplier=3.32 159 | python3 -m transfer.transfer_cifar --feature_path=transfer/features/simclr_r50_2x_sk1 --batch_size=1024 --lr=4 --noise_multiplier=2.40 160 | ``` 161 | -------------------------------------------------------------------------------- /baselines.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from opacus import PrivacyEngine 7 | from sklearn.linear_model import LogisticRegression 8 | 9 | from train_utils import get_device, train, test 10 | from data import get_data, get_scatter_transform, get_scattered_loader 11 | from models import ScatterLinear, get_num_params 12 | from dp_utils import ORDERS, get_privacy_spent, get_renyi_divergence, scatter_normalization 13 | from log import Logger 14 | 15 | 16 | def main(dataset, augment=False, batch_size=2048, mini_batch_size=256, sample_batches=False, 17 | lr=1, optim="SGD", momentum=0.9, nesterov=False, noise_multiplier=1, max_grad_norm=0.1, 18 | epochs=100, input_norm=None, num_groups=None, bn_noise_multiplier=None, 19 | max_epsilon=None, logdir=None): 20 | 21 | logger = Logger(logdir) 22 | device = get_device() 23 | 24 | train_data, test_data = get_data(dataset, augment=augment) 25 | scattering, K, (h, w) = get_scatter_transform(dataset) 26 | scattering.to(device) 27 | 28 | bs = batch_size 29 | assert bs % mini_batch_size == 0 30 | n_acc_steps = bs // mini_batch_size 31 | 32 | # Batch accumulation and data augmentation with Poisson sampling isn't implemented 33 | if sample_batches: 34 | assert n_acc_steps == 1 35 | assert not augment 36 | 37 | train_loader = torch.utils.data.DataLoader( 38 | train_data, batch_size=mini_batch_size, shuffle=True, num_workers=1, pin_memory=True) 39 | 40 | test_loader = torch.utils.data.DataLoader( 41 | test_data, batch_size=mini_batch_size, shuffle=False, num_workers=1, pin_memory=True) 42 | 43 | rdp_norm = 0 44 | if input_norm == "BN": 45 | # compute noisy data statistics or load from disk if pre-computed 46 | save_dir = f"bn_stats/{dataset}" 47 | os.makedirs(save_dir, exist_ok=True) 48 | bn_stats, rdp_norm = scatter_normalization(train_loader, 49 | scattering, 50 | K, 51 | device, 52 | len(train_data), 53 | len(train_data), 54 | noise_multiplier=bn_noise_multiplier, 55 | orders=ORDERS, 56 | save_dir=save_dir) 57 | model = ScatterLinear(K, (h, w), input_norm="BN", bn_stats=bn_stats) 58 | else: 59 | model = ScatterLinear(K, (h, w), input_norm=input_norm, num_groups=num_groups) 60 | 61 | model.to(device) 62 | 63 | if augment: 64 | model = nn.Sequential(scattering, model) 65 | train_loader = torch.utils.data.DataLoader( 66 | train_data, batch_size=mini_batch_size, shuffle=True, 67 | num_workers=1, pin_memory=True, drop_last=True) 68 | else: 69 | # if there is no data augmentation, pre-compute the scattering transform 70 | train_loader = get_scattered_loader(train_loader, scattering, device, 71 | drop_last=True, 72 | sample_batches=sample_batches) 73 | test_loader = get_scattered_loader(test_loader, scattering, device) 74 | 75 | # baseline Logistic Regression without privacy 76 | if optim == "LR": 77 | assert not augment 78 | X_train = [] 79 | y_train = [] 80 | X_test = [] 81 | y_test = [] 82 | for data, target in train_loader: 83 | with torch.no_grad(): 84 | data = 
data.to(device) 85 | X_train.append(data.cpu().numpy().reshape(len(data), -1)) 86 | y_train.extend(target.cpu().numpy()) 87 | 88 | for data, target in test_loader: 89 | with torch.no_grad(): 90 | data = data.to(device) 91 | X_test.append(data.cpu().numpy().reshape(len(data), -1)) 92 | y_test.extend(target.cpu().numpy()) 93 | 94 | import numpy as np 95 | X_train = np.concatenate(X_train, axis=0) 96 | X_test = np.concatenate(X_test, axis=0) 97 | y_train = np.asarray(y_train) 98 | y_test = np.asarray(y_test) 99 | 100 | print(X_train.shape, y_train.shape, X_test.shape, y_test.shape) 101 | 102 | for idx, C in enumerate([0.01, 0.1, 1.0, 10, 100]): 103 | clf = LogisticRegression(C=C, fit_intercept=True) 104 | clf.fit(X_train, y_train) 105 | 106 | train_acc = 100 * clf.score(X_train, y_train) 107 | test_acc = 100 * clf.score(X_test, y_test) 108 | print(f"C={C}, " 109 | f"Acc train = {train_acc: .2f}, " 110 | f"Acc test = {test_acc: .2f}") 111 | 112 | logger.log_epoch(idx, 0, train_acc, 0, test_acc, None) 113 | return 114 | 115 | print(f"model has {get_num_params(model)} parameters") 116 | 117 | if optim == "SGD": 118 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, 119 | momentum=momentum, 120 | nesterov=nesterov) 121 | else: 122 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 123 | 124 | privacy_engine = PrivacyEngine( 125 | model, 126 | sample_rate=bs / len(train_data), 127 | alphas=ORDERS, 128 | noise_multiplier=noise_multiplier, 129 | max_grad_norm=max_grad_norm, 130 | ) 131 | privacy_engine.attach(optimizer) 132 | 133 | for epoch in range(0, epochs): 134 | print(f"\nEpoch: {epoch}") 135 | train_loss, train_acc = train(model, train_loader, optimizer, n_acc_steps=n_acc_steps) 136 | test_loss, test_acc = test(model, test_loader) 137 | 138 | if noise_multiplier > 0: 139 | rdp_sgd = get_renyi_divergence( 140 | privacy_engine.sample_rate, privacy_engine.noise_multiplier 141 | ) * privacy_engine.steps 142 | epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd) 143 | epsilon2, _ = get_privacy_spent(rdp_sgd) 144 | print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})") 145 | 146 | if max_epsilon is not None and epsilon >= max_epsilon: 147 | return 148 | else: 149 | epsilon = None 150 | 151 | logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc, epsilon) 152 | logger.log_scalar("epsilon/train", epsilon, epoch) 153 | 154 | 155 | if __name__ == '__main__': 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument('--dataset', choices=['cifar10', 'fmnist', 'mnist']) 158 | parser.add_argument('--augment', action="store_true") 159 | parser.add_argument('--batch_size', type=int, default=2048) 160 | parser.add_argument('--mini_batch_size', type=int, default=256) 161 | parser.add_argument('--lr', type=float, default=0.01) 162 | parser.add_argument('--optim', type=str, default="SGD", 163 | choices=["SGD", "Adam", "LR"]) 164 | parser.add_argument('--momentum', type=float, default=0.9) 165 | parser.add_argument('--nesterov', action="store_true") 166 | parser.add_argument('--noise_multiplier', type=float, default=1) 167 | parser.add_argument('--max_grad_norm', type=float, default=0.1) 168 | parser.add_argument('--epochs', type=int, default=100) 169 | parser.add_argument('--input_norm', default=None, 170 | choices=["GroupNorm", "BN"]) 171 | parser.add_argument('--num_groups', type=int, default=81) 172 | parser.add_argument('--bn_noise_multiplier', type=float, default=6) 173 | parser.add_argument('--max_epsilon', type=float, default=None) 174 | 
parser.add_argument('--sample_batches', action="store_true") 175 | parser.add_argument('--logdir', default=None) 176 | args = parser.parse_args() 177 | main(**vars(args)) 178 | -------------------------------------------------------------------------------- /cnns.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | from opacus import PrivacyEngine 7 | 8 | from train_utils import get_device, train, test 9 | from data import get_data, get_scatter_transform, get_scattered_loader 10 | from models import CNNS, get_num_params 11 | from dp_utils import ORDERS, get_privacy_spent, get_renyi_divergence, scatter_normalization 12 | from log import Logger 13 | 14 | 15 | def main(dataset, augment=False, use_scattering=False, size=None, 16 | batch_size=2048, mini_batch_size=256, sample_batches=False, 17 | lr=1, optim="SGD", momentum=0.9, nesterov=False, 18 | noise_multiplier=1, max_grad_norm=0.1, epochs=100, 19 | input_norm=None, num_groups=None, bn_noise_multiplier=None, 20 | max_epsilon=None, logdir=None, early_stop=True): 21 | 22 | logger = Logger(logdir) 23 | device = get_device() 24 | 25 | train_data, test_data = get_data(dataset, augment=augment) 26 | 27 | if use_scattering: 28 | scattering, K, _ = get_scatter_transform(dataset) 29 | scattering.to(device) 30 | else: 31 | scattering = None 32 | K = 3 if len(train_data.data.shape) == 4 else 1 33 | 34 | bs = batch_size 35 | assert bs % mini_batch_size == 0 36 | n_acc_steps = bs // mini_batch_size 37 | 38 | # Batch accumulation and data augmentation with Poisson sampling isn't implemented 39 | if sample_batches: 40 | assert n_acc_steps == 1 41 | assert not augment 42 | 43 | train_loader = torch.utils.data.DataLoader( 44 | train_data, batch_size=mini_batch_size, shuffle=True, num_workers=1, pin_memory=True) 45 | 46 | test_loader = torch.utils.data.DataLoader( 47 | test_data, batch_size=mini_batch_size, shuffle=False, num_workers=1, pin_memory=True) 48 | 49 | rdp_norm = 0 50 | if input_norm == "BN": 51 | # compute noisy data statistics or load from disk if pre-computed 52 | save_dir = f"bn_stats/{dataset}" 53 | os.makedirs(save_dir, exist_ok=True) 54 | bn_stats, rdp_norm = scatter_normalization(train_loader, 55 | scattering, 56 | K, 57 | device, 58 | len(train_data), 59 | len(train_data), 60 | noise_multiplier=bn_noise_multiplier, 61 | orders=ORDERS, 62 | save_dir=save_dir) 63 | model = CNNS[dataset](K, input_norm="BN", bn_stats=bn_stats, size=size) 64 | else: 65 | model = CNNS[dataset](K, input_norm=input_norm, num_groups=num_groups, size=size) 66 | 67 | model.to(device) 68 | 69 | if use_scattering and augment: 70 | model = nn.Sequential(scattering, model) 71 | train_loader = torch.utils.data.DataLoader( 72 | train_data, batch_size=mini_batch_size, shuffle=True, 73 | num_workers=1, pin_memory=True, drop_last=True) 74 | else: 75 | # pre-compute the scattering transform if necessery 76 | train_loader = get_scattered_loader(train_loader, scattering, device, 77 | drop_last=True, sample_batches=sample_batches) 78 | test_loader = get_scattered_loader(test_loader, scattering, device) 79 | 80 | print(f"model has {get_num_params(model)} parameters") 81 | 82 | if optim == "SGD": 83 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, 84 | momentum=momentum, 85 | nesterov=nesterov) 86 | else: 87 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 88 | 89 | privacy_engine = PrivacyEngine( 90 | model, 91 | sample_rate=bs / len(train_data), 
92 | alphas=ORDERS, 93 | noise_multiplier=noise_multiplier, 94 | max_grad_norm=max_grad_norm, 95 | ) 96 | privacy_engine.attach(optimizer) 97 | 98 | best_acc = 0 99 | flat_count = 0 100 | 101 | for epoch in range(0, epochs): 102 | print(f"\nEpoch: {epoch}") 103 | 104 | train_loss, train_acc = train(model, train_loader, optimizer, n_acc_steps=n_acc_steps) 105 | test_loss, test_acc = test(model, test_loader) 106 | 107 | if noise_multiplier > 0: 108 | rdp_sgd = get_renyi_divergence( 109 | privacy_engine.sample_rate, privacy_engine.noise_multiplier 110 | ) * privacy_engine.steps 111 | epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd) 112 | epsilon2, _ = get_privacy_spent(rdp_sgd) 113 | print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})") 114 | 115 | if max_epsilon is not None and epsilon >= max_epsilon: 116 | return 117 | else: 118 | epsilon = None 119 | 120 | logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc, epsilon) 121 | logger.log_scalar("epsilon/train", epsilon, epoch) 122 | 123 | # stop if we're not making progress 124 | if test_acc > best_acc: 125 | best_acc = test_acc 126 | flat_count = 0 127 | else: 128 | flat_count += 1 129 | if flat_count >= 20 and early_stop: 130 | print("plateau...") 131 | return 132 | 133 | 134 | if __name__ == '__main__': 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument('--dataset', choices=['cifar10', 'fmnist', 'mnist']) 137 | parser.add_argument('--size', default=None) 138 | parser.add_argument('--augment', action="store_true") 139 | parser.add_argument('--use_scattering', action="store_true") 140 | parser.add_argument('--batch_size', type=int, default=2048) 141 | parser.add_argument('--mini_batch_size', type=int, default=256) 142 | parser.add_argument('--lr', type=float, default=0.01) 143 | parser.add_argument('--optim', type=str, default="SGD", choices=["SGD", "Adam"]) 144 | parser.add_argument('--momentum', type=float, default=0.9) 145 | parser.add_argument('--nesterov', action="store_true") 146 | parser.add_argument('--noise_multiplier', type=float, default=1) 147 | parser.add_argument('--max_grad_norm', type=float, default=0.1) 148 | parser.add_argument('--epochs', type=int, default=100) 149 | parser.add_argument('--input_norm', default=None, choices=["GroupNorm", "BN"]) 150 | parser.add_argument('--num_groups', type=int, default=81) 151 | parser.add_argument('--bn_noise_multiplier', type=float, default=6) 152 | parser.add_argument('--max_epsilon', type=float, default=None) 153 | parser.add_argument('--early_stop', action='store_false') 154 | parser.add_argument('--sample_batches', action="store_true") 155 | parser.add_argument('--logdir', default=None) 156 | args = parser.parse_args() 157 | main(**vars(args)) 158 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | from kymatio.torch import Scattering2D 4 | import os 5 | import pickle 6 | import numpy as np 7 | import logging 8 | 9 | 10 | SHAPES = { 11 | "cifar10": (32, 32, 3), 12 | "cifar10_500K": (32, 32, 3), 13 | "fmnist": (28, 28, 1), 14 | "mnist": (28, 28, 1) 15 | } 16 | 17 | 18 | def get_scatter_transform(dataset): 19 | shape = SHAPES[dataset] 20 | scattering = Scattering2D(J=2, shape=shape[:2]) 21 | K = 81 * shape[2] 22 | (h, w) = shape[:2] 23 | return scattering, K, (h//4, w//4) 24 | 25 | 26 | def get_data(name, augment=False, **kwargs): 27 | if name == "cifar10": 28 
| normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]) 30 | 31 | if augment: 32 | train_transforms = [ 33 | transforms.RandomHorizontalFlip(), 34 | transforms.RandomCrop(32, 4), 35 | transforms.ToTensor(), 36 | normalize, 37 | ] 38 | else: 39 | train_transforms = [ 40 | transforms.ToTensor(), 41 | normalize, 42 | ] 43 | 44 | train_set = datasets.CIFAR10(root=".data", train=True, 45 | transform=transforms.Compose(train_transforms), 46 | download=True) 47 | 48 | test_set = datasets.CIFAR10(root=".data", train=False, 49 | transform=transforms.Compose( 50 | [transforms.ToTensor(), normalize] 51 | )) 52 | 53 | elif name == "fmnist": 54 | train_set = datasets.FashionMNIST(root='.data', train=True, 55 | transform=transforms.ToTensor(), 56 | download=True) 57 | 58 | test_set = datasets.FashionMNIST(root='.data', train=False, 59 | transform=transforms.ToTensor(), 60 | download=True) 61 | 62 | elif name == "mnist": 63 | train_set = datasets.MNIST(root='.data', train=True, 64 | transform=transforms.ToTensor(), 65 | download=True) 66 | 67 | test_set = datasets.MNIST(root='.data', train=False, 68 | transform=transforms.ToTensor(), 69 | download=True) 70 | 71 | elif name == "cifar10_500K": 72 | 73 | # extended version of CIFAR-10 with pseudo-labelled tinyimages 74 | 75 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 76 | std=[0.229, 0.224, 0.225]) 77 | 78 | if augment: 79 | train_transforms = [ 80 | transforms.RandomHorizontalFlip(), 81 | transforms.RandomCrop(32, 4), 82 | transforms.ToTensor(), 83 | normalize, 84 | ] 85 | else: 86 | train_transforms = [ 87 | transforms.ToTensor(), 88 | normalize, 89 | ] 90 | 91 | train_set = SemiSupervisedDataset(kwargs['aux_data_filename'], 92 | root=".data", 93 | train=True, 94 | download=True, 95 | transform=transforms.Compose(train_transforms)) 96 | test_set = None 97 | else: 98 | raise ValueError(f"unknown dataset {name}") 99 | 100 | return train_set, test_set 101 | 102 | 103 | class SemiSupervisedDataset(torch.utils.data.Dataset): 104 | def __init__(self, 105 | aux_data_filename=None, 106 | train=False, 107 | **kwargs): 108 | """A dataset with auxiliary pseudo-labeled data""" 109 | 110 | self.dataset = datasets.CIFAR10(train=train, **kwargs) 111 | self.train = train 112 | 113 | # shuffle cifar-10 114 | p = np.random.permutation(len(self.data)) 115 | self.data = self.data[p] 116 | self.targets = list(np.asarray(self.targets)[p]) 117 | 118 | if self.train: 119 | self.sup_indices = list(range(len(self.targets))) 120 | self.unsup_indices = [] 121 | 122 | aux_path = os.path.join(kwargs['root'], aux_data_filename) 123 | print("Loading data from %s" % aux_path) 124 | with open(aux_path, 'rb') as f: 125 | aux = pickle.load(f) 126 | aux_data = aux['data'] 127 | aux_targets = aux['extrapolated_targets'] 128 | orig_len = len(self.data) 129 | 130 | # shuffle additional data 131 | p = np.random.permutation(len(aux_data)) 132 | aux_data = aux_data[p] 133 | aux_targets = aux_targets[p] 134 | 135 | self.data = np.concatenate((self.data, aux_data), axis=0) 136 | self.targets.extend(aux_targets) 137 | 138 | # note that we use unsup indices to track the labeled datapoints 139 | # whose labels are "fake" 140 | self.unsup_indices.extend( 141 | range(orig_len, orig_len+len(aux_data))) 142 | 143 | logger = logging.getLogger() 144 | logger.info("Training set") 145 | logger.info("Number of training samples: %d", len(self.targets)) 146 | logger.info("Number of supervised samples: %d", 147 | len(self.sup_indices)) 148 | 
logger.info("Number of unsup samples: %d", len(self.unsup_indices)) 149 | logger.info("Label (and pseudo-label) histogram: %s", 150 | tuple( 151 | zip(*np.unique(self.targets, return_counts=True)))) 152 | logger.info("Shape of training data: %s", np.shape(self.data)) 153 | 154 | # Test set 155 | else: 156 | self.sup_indices = list(range(len(self.targets))) 157 | self.unsup_indices = [] 158 | 159 | logger = logging.getLogger() 160 | logger.info("Test set") 161 | logger.info("Number of samples: %d", len(self.targets)) 162 | logger.info("Label histogram: %s", 163 | tuple( 164 | zip(*np.unique(self.targets, return_counts=True)))) 165 | logger.info("Shape of data: %s", np.shape(self.data)) 166 | 167 | @property 168 | def data(self): 169 | return self.dataset.data 170 | 171 | @data.setter 172 | def data(self, value): 173 | self.dataset.data = value 174 | 175 | @property 176 | def targets(self): 177 | return self.dataset.targets 178 | 179 | @targets.setter 180 | def targets(self, value): 181 | self.dataset.targets = value 182 | 183 | def __len__(self): 184 | return len(self.dataset) 185 | 186 | def __getitem__(self, item): 187 | self.dataset.labels = self.targets # because torchvision is annoying 188 | return self.dataset[item] 189 | 190 | def __repr__(self): 191 | fmt_str = 'Semisupervised Dataset ' + self.__class__.__name__ + '\n' 192 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 193 | fmt_str += ' Training: {}\n'.format(self.train) 194 | fmt_str += ' Root Location: {}\n'.format(self.dataset.root) 195 | tmp = ' Transforms (if any): ' 196 | fmt_str += '{0}{1}\n'.format(tmp, self.dataset.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 197 | tmp = ' Target Transforms (if any): ' 198 | fmt_str += '{0}{1}'.format(tmp, self.dataset.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 199 | return fmt_str 200 | 201 | 202 | class SemiSupervisedSampler(torch.utils.data.Sampler): 203 | def __init__(self, num_examples, num_batches, batch_size): 204 | self.inds = list(range(num_examples)) 205 | self.batch_size = batch_size 206 | self.num_batches = num_batches 207 | super().__init__(None) 208 | 209 | def __iter__(self): 210 | batch_counter = 0 211 | inds_shuffled = [self.inds[i] for i in torch.randperm(len(self.inds))] 212 | 213 | while len(inds_shuffled) < self.num_batches*self.batch_size: 214 | temp = [self.inds[i] for i in torch.randperm(len(self.inds))] 215 | inds_shuffled.extend(temp) 216 | 217 | for k in range(0, self.num_batches*self.batch_size, self.batch_size): 218 | if batch_counter == self.num_batches: 219 | break 220 | 221 | batch = inds_shuffled[k:(k + self.batch_size)] 222 | 223 | # this shuffle operation is very important, without it 224 | # batch-norm / DataParallel hell ensues 225 | np.random.shuffle(batch) 226 | yield batch 227 | batch_counter += 1 228 | 229 | def __len__(self): 230 | return self.num_batches 231 | 232 | 233 | class PoissonSampler(torch.utils.data.Sampler): 234 | def __init__(self, num_examples, batch_size): 235 | self.inds = np.arange(num_examples) 236 | self.batch_size = batch_size 237 | self.num_batches = int(np.ceil(num_examples / batch_size)) 238 | self.sample_rate = self.batch_size / (1.0 * num_examples) 239 | super().__init__(None) 240 | 241 | def __iter__(self): 242 | # select each data point independently with probability `sample_rate` 243 | for i in range(self.num_batches): 244 | batch_idxs = np.random.binomial(n=1, p=self.sample_rate, size=len(self.inds)) 245 | batch = self.inds[batch_idxs.astype(np.bool)] 246 | 
np.random.shuffle(batch) 247 | yield batch 248 | 249 | def __len__(self): 250 | return self.num_batches 251 | 252 | 253 | def get_scattered_dataset(loader, scattering, device, data_size): 254 | # pre-compute a scattering transform (if there is one) and return 255 | # a TensorDataset 256 | 257 | scatters = [] 258 | targets = [] 259 | 260 | num = 0 261 | for (data, target) in loader: 262 | data, target = data.to(device), target.to(device) 263 | if scattering is not None: 264 | data = scattering(data) 265 | scatters.append(data) 266 | targets.append(target) 267 | 268 | num += len(data) 269 | if num > data_size: 270 | break 271 | 272 | scatters = torch.cat(scatters, axis=0) 273 | targets = torch.cat(targets, axis=0) 274 | 275 | scatters = scatters[:data_size] 276 | targets = targets[:data_size] 277 | 278 | data = torch.utils.data.TensorDataset(scatters, targets) 279 | return data 280 | 281 | 282 | def get_scattered_loader(loader, scattering, device, drop_last=False, sample_batches=False): 283 | # pre-compute a scattering transform (if there is one) and return 284 | # a DataLoader 285 | 286 | scatters = [] 287 | targets = [] 288 | 289 | for (data, target) in loader: 290 | data, target = data.to(device), target.to(device) 291 | if scattering is not None: 292 | data = scattering(data) 293 | scatters.append(data) 294 | targets.append(target) 295 | 296 | scatters = torch.cat(scatters, axis=0) 297 | targets = torch.cat(targets, axis=0) 298 | 299 | data = torch.utils.data.TensorDataset(scatters, targets) 300 | 301 | if sample_batches: 302 | sampler = PoissonSampler(len(scatters), loader.batch_size) 303 | return torch.utils.data.DataLoader(data, batch_sampler=sampler, 304 | num_workers=0, pin_memory=False) 305 | else: 306 | shuffle = isinstance(loader.sampler, torch.utils.data.RandomSampler) 307 | return torch.utils.data.DataLoader(data, 308 | batch_size=loader.batch_size, 309 | shuffle=shuffle, 310 | num_workers=0, 311 | pin_memory=False, 312 | drop_last=drop_last) 313 | -------------------------------------------------------------------------------- /dp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | import opacus.privacy_analysis as tf_privacy 7 | 8 | ORDERS = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)) 9 | 10 | 11 | def get_renyi_divergence(sample_rate, noise_multiplier, orders=ORDERS): 12 | rdp = torch.tensor( 13 | tf_privacy.compute_rdp( 14 | sample_rate, noise_multiplier, 1, orders 15 | ) 16 | ) 17 | return rdp 18 | 19 | 20 | def get_privacy_spent(total_rdp, target_delta=1e-5, orders=ORDERS): 21 | return tf_privacy.get_privacy_spent(orders, total_rdp, target_delta) 22 | 23 | 24 | def get_epsilon(sample_rate, mul, num_steps, target_delta=1e-5, orders=ORDERS, rdp_init=0): 25 | # compute the epsilon budget spent after `num_steps` with batch sampling rate 26 | # of `sample_rate` and a noise multiplier of `mul` 27 | 28 | rdp = rdp_init + get_renyi_divergence(sample_rate, mul, orders=orders) * num_steps 29 | eps, _ = get_privacy_spent(rdp, target_delta=target_delta, orders=orders) 30 | return eps 31 | 32 | 33 | def get_noise_mul(num_samples, batch_size, target_epsilon, epochs, rdp_init=0, target_delta=1e-5, orders=ORDERS): 34 | # compute the noise multiplier that results in a privacy budget 35 | # of `target_epsilon` being spent after a given number of epochs of DP-SGD. 
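# Note: a larger noise multiplier yields a smaller epsilon, so the bisection below starts from mul_low=100 (epsilon below the target) and mul_high=0.1 (epsilon above the target) and narrows the bracket until the two achieved epsilons differ by at most 0.01; the conservative endpoint mul_low is returned.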
36 | 37 | mul_low = 100 38 | mul_high = 0.1 39 | 40 | num_steps = math.floor(num_samples // batch_size) * epochs 41 | sample_rate = batch_size / (1.0 * num_samples) 42 | 43 | eps_low = get_epsilon(sample_rate, mul_low, num_steps, target_delta, orders, rdp_init=rdp_init) 44 | eps_high = get_epsilon(sample_rate, mul_high, num_steps, target_delta, orders, rdp_init=rdp_init) 45 | 46 | assert eps_low < target_epsilon 47 | assert eps_high > target_epsilon 48 | 49 | while eps_high - eps_low > 0.01: 50 | mul_mid = (mul_high + mul_low) / 2 51 | eps_mid = get_epsilon(sample_rate, mul_mid, num_steps, target_delta, orders, rdp_init=rdp_init) 52 | 53 | if eps_mid <= target_epsilon: 54 | mul_low = mul_mid 55 | eps_low = eps_mid 56 | else: 57 | mul_high = mul_mid 58 | eps_high = eps_mid 59 | 60 | return mul_low 61 | 62 | 63 | def get_noise_mul_privbyiter(num_samples, batch_size, target_epsilon, epochs, target_delta=1e-5): 64 | mul_low = 100 65 | mul_high = 0.1 66 | 67 | eps_low = priv_by_iter_guarantees(epochs, batch_size, num_samples, mul_low, target_delta, verbose=False) 68 | eps_high = priv_by_iter_guarantees(epochs, batch_size, num_samples, mul_high, target_delta, verbose=False) 69 | 70 | assert eps_low < target_epsilon 71 | assert eps_high > target_epsilon 72 | 73 | while eps_high - eps_low > 0.01: 74 | mul_mid = (mul_high + mul_low) / 2 75 | eps_mid = priv_by_iter_guarantees(epochs, batch_size, num_samples, mul_mid, target_delta, verbose=False) 76 | 77 | if eps_mid <= target_epsilon: 78 | mul_low = mul_mid 79 | eps_low = eps_mid 80 | else: 81 | mul_high = mul_mid 82 | eps_high = eps_mid 83 | 84 | return mul_low 85 | 86 | 87 | def scatter_normalization(train_loader, scattering, K, device, 88 | data_size, sample_size, 89 | noise_multiplier=1.0, orders=ORDERS, save_dir=None): 90 | # privately compute the mean and variance of scatternet features to normalize 91 | # the data. 
92 | 93 | rdp = 0 94 | epsilon_norm = np.inf 95 | if noise_multiplier > 0: 96 | # compute the RDP spent in this step 97 | sample_rate = sample_size / (1.0 * data_size) 98 | rdp = 2*get_renyi_divergence(sample_rate, noise_multiplier, orders) 99 | epsilon_norm, _ = get_privacy_spent(rdp) 100 | 101 | # try loading pre-computed stats 102 | use_scattering = scattering is not None 103 | assert use_scattering 104 | mean_path = os.path.join(save_dir, f"mean_bn_{sample_size}_{noise_multiplier}_{use_scattering}.npy") 105 | var_path = os.path.join(save_dir, f"var_bn_{sample_size}_{noise_multiplier}_{use_scattering}.npy") 106 | 107 | print(f"Using BN stats for {sample_size}/{data_size} samples") 108 | print(f"With noise_mul={noise_multiplier}, we get ε_norm = {epsilon_norm:.3f}") 109 | 110 | try: 111 | print(f"loading {mean_path}") 112 | mean = np.load(mean_path) 113 | var = np.load(var_path) 114 | print(mean.shape, var.shape) 115 | except OSError: 116 | 117 | # compute the scattering transform and the mean and squared mean of features 118 | scatters = [] 119 | mean = 0 120 | sq_mean = 0 121 | count = 0 122 | for idx, (data, target) in enumerate(train_loader): 123 | with torch.no_grad(): 124 | data = data.to(device) 125 | if scattering is not None: 126 | data = scattering(data).view(-1, K, data.shape[2]//4, data.shape[3]//4) 127 | if noise_multiplier == 0: 128 | data = data.reshape(len(data), K, -1).mean(-1) 129 | mean += data.sum(0).cpu().numpy() 130 | sq_mean += (data**2).sum(0).cpu().numpy() 131 | else: 132 | scatters.append(data.cpu().numpy()) 133 | 134 | count += len(data) 135 | if count >= sample_size: 136 | break 137 | 138 | if noise_multiplier > 0: 139 | scatters = np.concatenate(scatters, axis=0) 140 | scatters = np.transpose(scatters, (0, 2, 3, 1)) 141 | 142 | scatters = scatters[:sample_size] 143 | 144 | # s x K 145 | scatter_means = np.mean(scatters.reshape(len(scatters), -1, K), axis=1) 146 | norms = np.linalg.norm(scatter_means, axis=-1) 147 | 148 | # technically a small privacy leak, sue me... 149 | thresh_mean = np.quantile(norms, 0.5) 150 | scatter_means /= np.maximum(norms / thresh_mean, 1).reshape(-1, 1) 151 | mean = np.mean(scatter_means, axis=0) 152 | 153 | mean += np.random.normal(scale=thresh_mean * noise_multiplier, 154 | size=mean.shape) / sample_size 155 | 156 | # s x K 157 | scatter_sq_means = np.mean((scatters ** 2).reshape(len(scatters), -1, K), 158 | axis=1) 159 | norms = np.linalg.norm(scatter_sq_means, axis=-1) 160 | 161 | # technically a small privacy leak, sue me... 
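# Same recipe as for the means above: clip each per-sample vector of averaged squared features to the median norm, average over the sampled points, and add Gaussian noise calibrated to that clipping threshold.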
162 | thresh_var = np.quantile(norms, 0.5) 163 | print(f"thresh_mean={thresh_mean:.2f}, thresh_var={thresh_var:.2f}") 164 | scatter_sq_means /= np.maximum(norms / thresh_var, 1).reshape(-1, 1) 165 | sq_mean = np.mean(scatter_sq_means, axis=0) 166 | sq_mean += np.random.normal(scale=thresh_var * noise_multiplier, 167 | size=sq_mean.shape) / sample_size 168 | var = np.maximum(sq_mean - mean ** 2, 0) 169 | else: 170 | mean /= count 171 | sq_mean /= count 172 | var = np.maximum(sq_mean - mean ** 2, 0) 173 | 174 | if save_dir is not None: 175 | print(f"saving mean and var: {mean.shape} {var.shape}") 176 | np.save(mean_path, mean) 177 | np.save(var_path, var) 178 | 179 | mean = torch.from_numpy(mean).to(device) 180 | var = torch.from_numpy(var).to(device) 181 | 182 | return (mean, var), rdp 183 | 184 | 185 | def priv_by_iter_guarantees(epochs, batch_size, samples, noise_multiplier, delta=1e-5, verbose=True): 186 | """Tabulating position-dependent privacy guarantees.""" 187 | if noise_multiplier == 0: 188 | if verbose: 189 | print('No differential privacy (additive noise is 0).') 190 | return np.inf 191 | 192 | if verbose: 193 | print('In the conditions of Theorem 34 (https://arxiv.org/abs/1808.06651) ' 194 | 'the training procedure results in the following privacy guarantees.') 195 | print('Out of the total of {} samples:'.format(samples)) 196 | 197 | steps_per_epoch = samples // batch_size 198 | orders = np.concatenate([np.linspace(2, 20, num=181), np.linspace(20, 100, num=81)]) 199 | for p in (.5, .9, .99): 200 | steps = math.ceil(steps_per_epoch * p) # Steps in the last epoch. 201 | coef = 2 * (noise_multiplier)**-2 * ( 202 | # Accounting for privacy loss 203 | (epochs - 1) / steps_per_epoch + # ... from all-but-last epochs 204 | 1 / (steps_per_epoch - steps + 1)) # ... due to the last epoch 205 | # Using RDP accountant to compute eps. Doing computation analytically is 206 | # an option. 
207 | rdp = [order * coef for order in orders] 208 | eps, _ = get_privacy_spent(rdp, delta, orders) 209 | if verbose: 210 | print('\t{:g}% enjoy at least ({:.2f}, {})-DP'.format( 211 | p * 100, eps, delta)) 212 | 213 | return eps 214 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import shutil 4 | import sys 5 | from torch.utils.tensorboard import SummaryWriter 6 | import torch 7 | 8 | 9 | def model_input(data, device): 10 | datum = data.data[0:1] 11 | if isinstance(datum, np.ndarray): 12 | return torch.from_numpy(datum).float().to(device) 13 | else: 14 | return datum.float().to(device) 15 | 16 | 17 | def get_script(): 18 | py_script = os.path.basename(sys.argv[0]) 19 | return os.path.splitext(py_script)[0] 20 | 21 | 22 | def get_specified_params(hparams): 23 | keys = [k.split("=")[0][2:] for k in sys.argv[1:]] 24 | specified = {k: hparams[k] for k in keys} 25 | return specified 26 | 27 | 28 | def make_hparam_str(hparams, exclude): 29 | return ",".join([f"{key}_{value}" 30 | for key, value in sorted(hparams.items()) 31 | if key not in exclude]) 32 | 33 | 34 | class Logger(object): 35 | def __init__(self, logdir): 36 | 37 | if logdir is None: 38 | self.writer = None 39 | else: 40 | if os.path.exists(logdir) and os.path.isdir(logdir): 41 | shutil.rmtree(logdir) 42 | 43 | self.writer = SummaryWriter(log_dir=logdir) 44 | 45 | def log_model(self, model, input_to_model): 46 | if self.writer is None: 47 | return 48 | self.writer.add_graph(model, input_to_model) 49 | 50 | def log_epoch(self, epoch, train_loss, train_acc, test_loss, test_acc, epsilon=None): 51 | if self.writer is None: 52 | return 53 | self.writer.add_scalar("Loss/train", train_loss, epoch) 54 | self.writer.add_scalar("Loss/test", test_loss, epoch) 55 | self.writer.add_scalar("Accuracy/train", train_acc, epoch) 56 | self.writer.add_scalar("Accuracy/test", test_acc, epoch) 57 | 58 | if epsilon is not None: 59 | self.writer.add_scalar("Acc@Eps/train", train_acc, 100*epsilon) 60 | self.writer.add_scalar("Acc@Eps/test", test_acc, 100*epsilon) 61 | 62 | def log_scalar(self, tag, scalar_value, global_step): 63 | if self.writer is None or scalar_value is None: 64 | return 65 | self.writer.add_scalar(tag, scalar_value, global_step) 66 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def standardize(x, bn_stats): 6 | if bn_stats is None: 7 | return x 8 | 9 | bn_mean, bn_var = bn_stats 10 | 11 | view = [1] * len(x.shape) 12 | view[1] = -1 13 | x = (x - bn_mean.view(view)) / torch.sqrt(bn_var.view(view) + 1e-5) 14 | 15 | # if variance is too low, just ignore 16 | x *= (bn_var.view(view) != 0).float() 17 | return x 18 | 19 | 20 | def clip_data(data, max_norm): 21 | norms = torch.norm(data.reshape(data.shape[0], -1), dim=-1) 22 | scale = (max_norm / norms).clamp(max=1.0) 23 | data *= scale.reshape(-1, 1, 1, 1) 24 | return data 25 | 26 | 27 | def get_num_params(model): 28 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 29 | 30 | 31 | class StandardizeLayer(nn.Module): 32 | def __init__(self, bn_stats): 33 | super(StandardizeLayer, self).__init__() 34 | self.bn_stats = bn_stats 35 | 36 | def forward(self, x): 37 | return standardize(x, self.bn_stats) 38 | 39 | 40 | class 
ClipLayer(nn.Module): 41 | def __init__(self, max_norm): 42 | super(ClipLayer, self).__init__() 43 | self.max_norm = max_norm 44 | 45 | def forward(self, x): 46 | return clip_data(x, self.max_norm) 47 | 48 | 49 | class CIFAR10_CNN(nn.Module): 50 | def __init__(self, in_channels=3, input_norm=None, **kwargs): 51 | super(CIFAR10_CNN, self).__init__() 52 | self.in_channels = in_channels 53 | self.features = None 54 | self.classifier = None 55 | self.norm = None 56 | 57 | self.build(input_norm, **kwargs) 58 | 59 | def build(self, input_norm=None, num_groups=None, 60 | bn_stats=None, size=None): 61 | 62 | if self.in_channels == 3: 63 | if size == "small": 64 | cfg = [16, 16, 'M', 32, 32, 'M', 64, 'M'] 65 | else: 66 | cfg = [32, 32, 'M', 64, 64, 'M', 128, 128, 'M'] 67 | 68 | self.norm = nn.Identity() 69 | else: 70 | if size == "small": 71 | cfg = [16, 16, 'M', 32, 32] 72 | else: 73 | cfg = [64, 'M', 64] 74 | if input_norm is None: 75 | self.norm = nn.Identity() 76 | elif input_norm == "GroupNorm": 77 | self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False) 78 | else: 79 | self.norm = lambda x: standardize(x, bn_stats) 80 | 81 | layers = [] 82 | act = nn.Tanh 83 | 84 | c = self.in_channels 85 | for v in cfg: 86 | if v == 'M': 87 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 88 | else: 89 | conv2d = nn.Conv2d(c, v, kernel_size=3, stride=1, padding=1) 90 | 91 | layers += [conv2d, act()] 92 | c = v 93 | 94 | self.features = nn.Sequential(*layers) 95 | 96 | if self.in_channels == 3: 97 | hidden = 128 98 | self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden), act(), nn.Linear(hidden, 10)) 99 | else: 100 | self.classifier = nn.Linear(c * 4 * 4, 10) 101 | 102 | def forward(self, x): 103 | if self.in_channels != 3: 104 | x = self.norm(x.view(-1, self.in_channels, 8, 8)) 105 | x = self.features(x) 106 | x = x.view(x.size(0), -1) 107 | x = self.classifier(x) 108 | return x 109 | 110 | 111 | class MNIST_CNN(nn.Module): 112 | def __init__(self, in_channels=1, input_norm=None, **kwargs): 113 | super(MNIST_CNN, self).__init__() 114 | self.in_channels = in_channels 115 | self.features = None 116 | self.classifier = None 117 | self.norm = None 118 | 119 | self.build(input_norm, **kwargs) 120 | 121 | def build(self, input_norm=None, num_groups=None, 122 | bn_stats=None, size=None): 123 | if self.in_channels == 1: 124 | ch1, ch2 = (16, 32) if size is None else (32, 64) 125 | cfg = [(ch1, 8, 2, 2), 'M', (ch2, 4, 2, 0), 'M'] 126 | self.norm = nn.Identity() 127 | else: 128 | ch1, ch2 = (16, 32) if size is None else (32, 64) 129 | cfg = [(ch1, 3, 2, 1), (ch2, 3, 1, 1)] 130 | if input_norm == "GroupNorm": 131 | self.norm = nn.GroupNorm(num_groups, self.in_channels, affine=False) 132 | elif input_norm == "BN": 133 | self.norm = lambda x: standardize(x, bn_stats) 134 | else: 135 | self.norm = nn.Identity() 136 | 137 | layers = [] 138 | 139 | c = self.in_channels 140 | for v in cfg: 141 | if v == 'M': 142 | layers += [nn.MaxPool2d(kernel_size=2, stride=1)] 143 | else: 144 | filters, k_size, stride, pad = v 145 | conv2d = nn.Conv2d(c, filters, kernel_size=k_size, stride=stride, padding=pad) 146 | 147 | layers += [conv2d, nn.Tanh()] 148 | c = filters 149 | 150 | self.features = nn.Sequential(*layers) 151 | 152 | hidden = 32 153 | self.classifier = nn.Sequential(nn.Linear(c * 4 * 4, hidden), 154 | nn.Tanh(), 155 | nn.Linear(hidden, 10)) 156 | 157 | def forward(self, x): 158 | if self.in_channels != 1: 159 | x = self.norm(x.view(-1, self.in_channels, 7, 7)) 160 | x = self.features(x) 161 | x = 
x.view(x.size(0), -1) 162 | x = self.classifier(x) 163 | return x 164 | 165 | 166 | class ScatterLinear(nn.Module): 167 | def __init__(self, in_channels, hw_dims, input_norm=None, classes=10, clip_norm=None, **kwargs): 168 | super(ScatterLinear, self).__init__() 169 | self.K = in_channels 170 | self.h = hw_dims[0] 171 | self.w = hw_dims[1] 172 | self.fc = None 173 | self.norm = None 174 | self.clip = None 175 | self.build(input_norm, classes=classes, clip_norm=clip_norm, **kwargs) 176 | 177 | def build(self, input_norm=None, num_groups=None, bn_stats=None, clip_norm=None, classes=10): 178 | self.fc = nn.Linear(self.K * self.h * self.w, classes) 179 | 180 | if input_norm is None: 181 | self.norm = nn.Identity() 182 | elif input_norm == "GroupNorm": 183 | self.norm = nn.GroupNorm(num_groups, self.K, affine=False) 184 | else: 185 | self.norm = lambda x: standardize(x, bn_stats) 186 | 187 | if clip_norm is None: 188 | self.clip = nn.Identity() 189 | else: 190 | self.clip = ClipLayer(clip_norm) 191 | 192 | def forward(self, x): 193 | x = self.norm(x.view(-1, self.K, self.h, self.w)) 194 | x = self.clip(x) 195 | x = x.reshape(x.size(0), -1) 196 | x = self.fc(x) 197 | return x 198 | 199 | 200 | CNNS = { 201 | "cifar10": CIFAR10_CNN, 202 | "fmnist": MNIST_CNN, 203 | "mnist": MNIST_CNN, 204 | } 205 | 206 | 207 | 208 | 209 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.11.0 2 | appdirs==1.4.4 3 | astunparse==1.6.3 4 | attrs==20.1.0 5 | bleach==3.1.5 6 | cachetools==4.1.1 7 | certifi==2020.6.20 8 | cffi==1.14.3 9 | chardet==3.0.4 10 | colorama==0.4.3 11 | configparser==5.0.1 12 | cryptography==3.2.1 13 | cycler==0.10.0 14 | dataclasses==0.6 15 | docutils==0.16 16 | future==0.18.2 17 | gast==0.3.3 18 | google-auth==1.23.0 19 | google-auth-oauthlib==0.4.2 20 | google-pasta==0.2.0 21 | grpcio==1.33.2 22 | h5py==2.10.0 23 | idna==2.10 24 | importlib-metadata==1.7.0 25 | iniconfig==1.0.1 26 | jeepney==0.5.0 27 | joblib==0.17.0 28 | Keras-Preprocessing==1.1.2 29 | keyring==21.4.0 30 | kiwisolver==1.3.1 31 | kymatio==0.2.0 32 | Markdown==3.3.3 33 | matplotlib==3.3.3 34 | more-itertools==8.5.0 35 | numpy==1.19.1 36 | oauthlib==3.1.0 37 | opacus==0.13.0 38 | opt-einsum==3.3.0 39 | packaging==20.4 40 | Pillow==7.2.0 41 | pkginfo==1.5.0.1 42 | pluggy==0.13.1 43 | protobuf==3.13.0 44 | py==1.9.0 45 | pyasn1==0.4.8 46 | pyasn1-modules==0.2.8 47 | pycparser==2.20 48 | Pygments==2.6.1 49 | pyparsing==2.4.7 50 | pytest==6.0.1 51 | python-dateutil==2.8.1 52 | readme-renderer==26.0 53 | requests==2.25.1 54 | requests-oauthlib==1.3.0 55 | requests-toolbelt==0.9.1 56 | rfc3986==1.4.0 57 | rsa==4.6 58 | scikit-learn==0.23.2 59 | scipy==1.5.2 60 | SecretStorage==3.2.0 61 | six==1.15.0 62 | tensorboard==2.3.0 63 | tensorboard-plugin-wit==1.7.0 64 | tensorflow==2.3.1 65 | tensorflow-estimator==2.3.0 66 | tensorflow-hub==0.10.0 67 | termcolor==1.1.0 68 | threadpoolctl==2.1.0 69 | toml==0.10.1 70 | torch==1.6.0 71 | torchcsprng==0.1.1 72 | torchvision==0.7.0 73 | tqdm==4.48.2 74 | twine==3.2.0 75 | typing-extensions==3.7.4.3 76 | urllib3==1.25.10 77 | webencodings==0.5.1 78 | Werkzeug==1.0.1 79 | wrapt==1.12.1 80 | zipp==3.1.0 81 | -------------------------------------------------------------------------------- /scripts/run_baselines_cifar10.py: -------------------------------------------------------------------------------- 1 | 2 | from baselines import main 3 | from dp_utils 
import get_noise_mul, get_renyi_divergence 4 | 5 | MAX_GRAD_NORM = 0.1 6 | MAX_EPS = 5 7 | 8 | BATCH_SIZES = [512, 1024, 2048, 4096, 8192, 16384] 9 | BASE_LRS = [0.125, 0.25, 0.5, 1.0] 10 | 11 | TARGET_EPS = 3 12 | TARGET_EPOCHS = [30, 60, 120] 13 | 14 | BN_MULS = [6, 8] 15 | GROUPS = [9, 27, 81] 16 | 17 | for target_epoch in TARGET_EPOCHS: 18 | for bs in BATCH_SIZES: 19 | for bn_mul in BN_MULS: 20 | rdp_norm = 2 * get_renyi_divergence(1.0, bn_mul) 21 | mul = get_noise_mul(50000, bs, TARGET_EPS, target_epoch, rdp_init=rdp_norm) 22 | 23 | for base_lr in BASE_LRS: 24 | lr = (bs // 512) * base_lr 25 | 26 | print(f"epoch={target_epoch}, bs={bs}, bn_mul={bn_mul}, lr={base_lr}*{bs//512}={lr}, mul={mul}") 27 | logdir = f"logs/baselines/cifar10/bs={bs}_lr={lr}_mul={mul:.2f}_bn={bn_mul}" 28 | 29 | main(dataset="cifar10", max_grad_norm=MAX_GRAD_NORM, 30 | lr=lr, batch_size=bs, noise_multiplier=mul, 31 | input_norm="BN", bn_noise_multiplier=bn_mul, 32 | max_epsilon=MAX_EPS, logdir=logdir, epochs=150) 33 | 34 | 35 | for target_epoch in TARGET_EPOCHS: 36 | for bs in BATCH_SIZES: 37 | for group in GROUPS: 38 | mul = get_noise_mul(50000, bs, TARGET_EPS, target_epoch) 39 | 40 | for base_lr in BASE_LRS: 41 | lr = (bs // 512) * base_lr 42 | 43 | print(f"epoch={target_epoch}, bs={bs}, GN={group}, lr={base_lr}*{bs//512}={lr}, mul={mul}") 44 | logdir = f"logs/baselines/cifar10/bs={bs}_lr={lr}_mul={mul:.2f}_GN={group}" 45 | 46 | main(dataset="cifar10", max_grad_norm=MAX_GRAD_NORM, 47 | lr=lr, batch_size=bs, noise_multiplier=mul, 48 | input_norm="GroupNorm", num_groups=group, 49 | max_epsilon=MAX_EPS, logdir=logdir, epochs=150) 50 | -------------------------------------------------------------------------------- /scripts/run_cnns_cifar10.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from cnns import main 4 | from dp_utils import get_noise_mul 5 | 6 | MAX_GRAD_NORM = 0.1 7 | MAX_EPS = 5 8 | 9 | BATCH_SIZES = [512, 1024, 2048, 4096, 8192, 16384] 10 | BASE_LRS = [0.125, 0.25, 0.5, 1.0] 11 | #BASE_LRS += [2.0, 4.0, 8.0] 12 | 13 | TARGET_EPS = 3 14 | TARGET_EPOCHS = [30, 60, 120] 15 | 16 | for target_epoch in TARGET_EPOCHS: 17 | for base_lr in BASE_LRS: 18 | for bs in BATCH_SIZES: 19 | lr = (bs // 512) * base_lr 20 | mul = get_noise_mul(50000, bs, TARGET_EPS, target_epoch) 21 | 22 | print(f"epoch={target_epoch}, bs={bs}, lr={base_lr}*{bs//512}={lr}, mul={mul}") 23 | 24 | logdir = f"logs/cnns/cifar10/bs={bs}_lr={lr}_mul={mul:.2f}" 25 | main(dataset="cifar10", max_grad_norm=MAX_GRAD_NORM, 26 | lr=lr, batch_size=bs, noise_multiplier=mul, 27 | max_epsilon=MAX_EPS, logdir=logdir, epochs=150) 28 | -------------------------------------------------------------------------------- /scripts/run_cnns_cifar10_scat.py: -------------------------------------------------------------------------------- 1 | 2 | from cnns import main 3 | from dp_utils import get_noise_mul, get_renyi_divergence 4 | 5 | MAX_GRAD_NORM = 0.1 6 | MAX_EPS = 3.5 7 | 8 | BATCH_SIZES = [512, 1024, 2048, 4096, 8192, 16384] 9 | BASE_LRS = [0.125, 0.25, 0.5, 1.0] 10 | 11 | TARGET_EPS = 3 12 | TARGET_EPOCHS = [30, 60, 120] 13 | 14 | BN_MULS = [6, 8] 15 | GN = [9, 27, 81] 16 | 17 | for target_epoch in TARGET_EPOCHS: 18 | for base_lr in BASE_LRS: 19 | for bs in BATCH_SIZES: 20 | for bn_mul in BN_MULS: 21 | rdp_norm = 2 * get_renyi_divergence(1.0, bn_mul) 22 | mul = get_noise_mul(50000, bs, TARGET_EPS, target_epoch, rdp_init=rdp_norm) 23 | lr = (bs // 512) * base_lr 24 | print(f"epoch={target_epoch}, bs={bs}, 
lr={base_lr}*{bs//512}={lr}, mul={mul}, bn={bn_mul}") 25 | logdir = f"logs/cnns+scat/cifar10/bs={bs}_lr={lr}_mul={mul:.2f}_bn={bn_mul}" 26 | main(dataset="cifar10", max_grad_norm=MAX_GRAD_NORM, 27 | lr=lr, batch_size=bs, noise_multiplier=mul, 28 | use_scattering=True, input_norm="BN", bn_noise_multiplier=bn_mul, 29 | max_epsilon=MAX_EPS, logdir=logdir, epochs=int(1.25*target_epoch)) 30 | 31 | for target_epoch in TARGET_EPOCHS: 32 | for base_lr in BASE_LRS: 33 | for bs in BATCH_SIZES: 34 | for group in GN: 35 | mul = get_noise_mul(50000, bs, TARGET_EPS, target_epoch, rdp_init=0) 36 | lr = (bs // 512) * base_lr 37 | print(f"epoch={target_epoch}, bs={bs}, lr={base_lr}*{bs//512}={lr}, mul={mul}, GN={group}") 38 | logdir = f"logs/cnns+scat/cifar10/bs={bs}_lr={lr}_mul={mul:.2f}_GN={group}" 39 | main(dataset="cifar10", max_grad_norm=MAX_GRAD_NORM, 40 | lr=lr, batch_size=bs, noise_multiplier=mul, 41 | use_scattering=True, input_norm="GroupNorm", num_groups=group, 42 | max_epsilon=MAX_EPS, logdir=logdir, epochs=int(1.25*target_epoch)) 43 | -------------------------------------------------------------------------------- /tiny_images.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | from opacus import PrivacyEngine 8 | 9 | from train_utils import get_device, train, test 10 | from data import get_data, SemiSupervisedSampler, get_scatter_transform, \ 11 | get_scattered_loader, get_scattered_dataset 12 | from models import CNNS, get_num_params, ScatterLinear 13 | from dp_utils import ORDERS, get_privacy_spent, get_renyi_divergence, scatter_normalization 14 | from log import Logger 15 | 16 | 17 | def main(tiny_images=None, model="cnn", augment=False, use_scattering=False, 18 | batch_size=2048, mini_batch_size=256, lr=1, lr_start=None, optim="SGD", 19 | momentum=0.9, noise_multiplier=1, max_grad_norm=0.1, 20 | epochs=100, bn_noise_multiplier=None, max_epsilon=None, 21 | data_size=550000, delta=1e-6, logdir=None): 22 | logger = Logger(logdir) 23 | 24 | device = get_device() 25 | 26 | bs = batch_size 27 | assert bs % mini_batch_size == 0 28 | n_acc_steps = bs // mini_batch_size 29 | 30 | train_data, test_data = get_data("cifar10", augment=augment) 31 | train_loader = torch.utils.data.DataLoader( 32 | train_data, batch_size=100, shuffle=False, num_workers=4, pin_memory=True) 33 | 34 | test_loader = torch.utils.data.DataLoader( 35 | test_data, batch_size=100, shuffle=False, num_workers=4, pin_memory=True) 36 | 37 | if isinstance(tiny_images, torch.utils.data.Dataset): 38 | train_data_aug = tiny_images 39 | else: 40 | print("loading tiny images...") 41 | train_data_aug, _ = get_data("cifar10_500K", augment=augment, 42 | aux_data_filename=tiny_images) 43 | 44 | scattering, K, (h, w) = None, None, (None, None) 45 | pre_scattered = False 46 | if use_scattering: 47 | scattering, K, (h, w) = get_scatter_transform("cifar10_500K") 48 | scattering.to(device) 49 | 50 | # if the whole data fits in memory, pre-compute the scattering 51 | if use_scattering and data_size <= 50000: 52 | loader = torch.utils.data.DataLoader(train_data_aug, batch_size=100, shuffle=False, num_workers=4) 53 | train_data_aug = get_scattered_dataset(loader, scattering, device, data_size) 54 | pre_scattered = True 55 | 56 | assert data_size <= len(train_data_aug) 57 | num_sup = min(data_size, 50000) 58 | num_batches = int(np.ceil(50000 / mini_batch_size)) # cifar-10 equivalent 59 | 60 | train_batch_sampler 
= SemiSupervisedSampler(data_size, num_batches, mini_batch_size) 61 | train_loader_aug = torch.utils.data.DataLoader(train_data_aug, 62 | batch_sampler=train_batch_sampler, 63 | num_workers=0 if pre_scattered else 4, 64 | pin_memory=not pre_scattered) 65 | 66 | rdp_norm = 0 67 | if model == "cnn": 68 | if use_scattering: 69 | save_dir = f"bn_stats/cifar10_500K" 70 | os.makedirs(save_dir, exist_ok=True) 71 | bn_stats, rdp_norm = scatter_normalization(train_loader, 72 | scattering, 73 | K, 74 | device, 75 | data_size, 76 | num_sup, 77 | noise_multiplier=bn_noise_multiplier, 78 | orders=ORDERS, 79 | save_dir=save_dir) 80 | model = CNNS["cifar10"](K, input_norm="BN", bn_stats=bn_stats) 81 | model = model.to(device) 82 | 83 | if not pre_scattered: 84 | model = nn.Sequential(scattering, model) 85 | else: 86 | model = CNNS["cifar10"](in_channels=3, internal_norm=False) 87 | 88 | elif model == "linear": 89 | save_dir = f"bn_stats/cifar10_500K" 90 | os.makedirs(save_dir, exist_ok=True) 91 | bn_stats, rdp_norm = scatter_normalization(train_loader, 92 | scattering, 93 | K, 94 | device, 95 | data_size, 96 | num_sup, 97 | noise_multiplier=bn_noise_multiplier, 98 | orders=ORDERS, 99 | save_dir=save_dir) 100 | model = ScatterLinear(K, (h, w), input_norm="BN", bn_stats=bn_stats) 101 | model = model.to(device) 102 | 103 | if not pre_scattered: 104 | model = nn.Sequential(scattering, model) 105 | else: 106 | raise ValueError(f"Unknown model {model}") 107 | model.to(device) 108 | 109 | if pre_scattered: 110 | test_loader = get_scattered_loader(test_loader, scattering, device) 111 | 112 | print(f"model has {get_num_params(model)} parameters") 113 | 114 | if optim == "SGD": 115 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum) 116 | else: 117 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 118 | 119 | privacy_engine = PrivacyEngine( 120 | model, 121 | sample_rate=bs / data_size, 122 | alphas=ORDERS, 123 | noise_multiplier=noise_multiplier, 124 | max_grad_norm=max_grad_norm, 125 | ) 126 | privacy_engine.attach(optimizer) 127 | 128 | best_acc = 0 129 | flat_count = 0 130 | 131 | for epoch in range(0, epochs): 132 | 133 | print(f"\nEpoch: {epoch} ({privacy_engine.steps} steps)") 134 | train_loss, train_acc = train(model, train_loader_aug, optimizer, n_acc_steps=n_acc_steps) 135 | test_loss, test_acc = test(model, test_loader) 136 | 137 | if noise_multiplier > 0: 138 | print(f"sample_rate={privacy_engine.sample_rate}, " 139 | f"mul={privacy_engine.noise_multiplier}, " 140 | f"steps={privacy_engine.steps}") 141 | rdp_sgd = get_renyi_divergence( 142 | privacy_engine.sample_rate, privacy_engine.noise_multiplier 143 | ) * privacy_engine.steps 144 | epsilon, _ = get_privacy_spent(rdp_norm + rdp_sgd, target_delta=delta) 145 | epsilon2, _ = get_privacy_spent(rdp_sgd, target_delta=delta) 146 | print(f"ε = {epsilon:.3f} (sgd only: ε = {epsilon2:.3f})") 147 | 148 | if max_epsilon is not None and epsilon >= max_epsilon: 149 | return 150 | else: 151 | epsilon = None 152 | 153 | logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc, epsilon) 154 | logger.log_scalar("epsilon/train", epsilon, epoch) 155 | logger.log_scalar("cifar10k_loss/train", train_loss, epoch) 156 | logger.log_scalar("cifar10k_acc/train", train_acc, epoch) 157 | 158 | if test_acc > best_acc: 159 | best_acc = test_acc 160 | flat_count = 0 161 | else: 162 | flat_count += 1 163 | if flat_count >= 20: 164 | print("plateau...") 165 | return 166 | 167 | 168 | if __name__ == '__main__': 169 | parser = 
argparse.ArgumentParser() 170 | parser.add_argument('--augment', action="store_true") 171 | parser.add_argument('--batch_size', type=int, default=128) 172 | parser.add_argument('--optim', type=str, default="SGD", choices=["SGD", "Adam"]) 173 | parser.add_argument('--lr', type=float, default=0.01) 174 | parser.add_argument('--lr_start', type=float, default=None) 175 | parser.add_argument('--momentum', type=float, default=0.9) 176 | parser.add_argument('--noise_multiplier', type=float, default=0) 177 | parser.add_argument('--max_grad_norm', type=float, default=0.1) 178 | parser.add_argument('--epochs', type=int, default=100) 179 | parser.add_argument('--model', choices=["cnn", "resnet", "linear"], default="cnn") 180 | parser.add_argument('--tiny_images', default="ti_500K_pseudo_labeled.pickle") 181 | parser.add_argument('--use_scattering', action="store_true") 182 | parser.add_argument('--bn_noise_multiplier', type=float, default=0) 183 | parser.add_argument('--logdir', default=None) 184 | parser.add_argument('--data_size', type=int, default=550_000) 185 | parser.add_argument('--delta', type=float, default=1e-6) 186 | args = parser.parse_args() 187 | main(**vars(args)) 188 | -------------------------------------------------------------------------------- /train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def get_device(): 6 | use_cuda = torch.cuda.is_available() 7 | assert use_cuda 8 | device = torch.device("cuda" if use_cuda else "cpu") 9 | return device 10 | 11 | 12 | def train(model, train_loader, optimizer, n_acc_steps=1): 13 | device = next(model.parameters()).device 14 | model.train() 15 | num_examples = 0 16 | correct = 0 17 | train_loss = 0 18 | 19 | rem = len(train_loader) % n_acc_steps 20 | num_batches = len(train_loader) 21 | num_batches -= rem 22 | 23 | bs = train_loader.batch_size if train_loader.batch_size is not None else train_loader.batch_sampler.batch_size 24 | print(f"training on {num_batches} batches of size {bs}") 25 | 26 | for batch_idx, (data, target) in enumerate(train_loader): 27 | 28 | if batch_idx > num_batches - 1: 29 | break 30 | 31 | data, target = data.to(device), target.to(device) 32 | 33 | output = model(data) 34 | 35 | loss = F.cross_entropy(output, target) 36 | loss.backward() 37 | 38 | if ((batch_idx + 1) % n_acc_steps == 0) or ((batch_idx + 1) == len(train_loader)): 39 | optimizer.step() 40 | optimizer.zero_grad() 41 | else: 42 | with torch.no_grad(): 43 | # accumulate per-example gradients but don't take a step yet 44 | optimizer.virtual_step() 45 | 46 | pred = output.max(1, keepdim=True)[1] 47 | correct += pred.eq(target.view_as(pred)).sum().item() 48 | train_loss += F.cross_entropy(output, target, reduction='sum').item() 49 | num_examples += len(data) 50 | 51 | train_loss /= num_examples 52 | train_acc = 100. 
* correct / num_examples 53 | 54 | print(f'Train set: Average loss: {train_loss:.4f}, ' 55 | f'Accuracy: {correct}/{num_examples} ({train_acc:.2f}%)') 56 | 57 | return train_loss, train_acc 58 | 59 | 60 | def test(model, test_loader): 61 | device = next(model.parameters()).device 62 | model.eval() 63 | num_examples = 0 64 | test_loss = 0 65 | correct = 0 66 | 67 | with torch.no_grad(): 68 | for data, target in test_loader: 69 | data, target = data.to(device), target.to(device) 70 | output = model(data) 71 | test_loss += F.cross_entropy(output, target, reduction='sum').item() 72 | pred = output.max(1, keepdim=True)[1] 73 | correct += pred.eq(target.view_as(pred)).sum().item() 74 | num_examples += len(data) 75 | 76 | test_loss /= num_examples 77 | test_acc = 100. * correct / num_examples 78 | 79 | print(f'Test set: Average loss: {test_loss:.4f}, ' 80 | f'Accuracy: {correct}/{num_examples} ({test_acc:.2f}%)') 81 | 82 | return test_loss, test_acc 83 | -------------------------------------------------------------------------------- /transfer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ftramer/Handcrafted-DP/6568e8f74d93b5e39c5e7a612962ddf34acaa22c/transfer/__init__.py -------------------------------------------------------------------------------- /transfer/extract_cifar100.py: -------------------------------------------------------------------------------- 1 | """ 2 | download model from https://github.com/bearpaw/pytorch-classification 3 | """ 4 | 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | import torch 8 | from transfer.resnext import resnext 9 | import numpy as np 10 | import os 11 | from sklearn.linear_model import LogisticRegression 12 | 13 | transform_test = transforms.Compose([ 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | 18 | trainset = datasets.CIFAR100(root='.data', train=True, download=True, transform=transform_test) 19 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=200, shuffle=False, num_workers=4) 20 | 21 | testset = datasets.CIFAR100(root='.data', train=False, download=True, transform=transform_test) 22 | testloader = torch.utils.data.DataLoader(testset, batch_size=200, shuffle=False, num_workers=4) 23 | 24 | model = resnext( 25 | cardinality=8, 26 | num_classes=100, 27 | depth=29, 28 | widen_factor=4, 29 | dropRate=0, 30 | ) 31 | 32 | model = torch.nn.DataParallel(model).cuda() 33 | model.eval() 34 | 35 | checkpoint = torch.load("transfer/resnext-8x64d/model_best.pth.tar") 36 | model.load_state_dict(checkpoint['state_dict']) 37 | 38 | with torch.no_grad(): 39 | acc = 0.0 40 | for batch_idx, (inputs, targets) in enumerate(testloader): 41 | inputs = inputs.cuda() 42 | outputs = torch.argmax(model(inputs), dim=-1) 43 | 44 | acc += torch.sum(outputs.cpu().eq(targets)) 45 | 46 | acc /= (1.0 * len(testset)) 47 | acc = (100 * acc).numpy() 48 | print(f"Test Acc on CIFAR 100 = {acc: .2f}") 49 | 50 | model.module.classifier = torch.nn.Identity() 51 | 52 | features_cifar100_train = [] 53 | with torch.no_grad(): 54 | for batch_idx, (inputs, targets) in enumerate(trainloader): 55 | f = model(inputs.cuda()).cpu().numpy() 56 | features_cifar100_train.append(f) 57 | 58 | features_cifar100_train = np.concatenate(features_cifar100_train, axis=0) 59 | print(features_cifar100_train.shape) 60 | 61 | mean_cifar100 = np.mean(features_cifar100_train, axis=0) 62 | var_cifar100 
= np.var(features_cifar100_train, axis=0) 63 | 64 | trainset = datasets.CIFAR10(root='.data', train=True, download=True, transform=transform_test) 65 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=200, shuffle=False, num_workers=4) 66 | 67 | testset = datasets.CIFAR10(root='.data', train=False, download=True, transform=transform_test) 68 | testloader = torch.utils.data.DataLoader(testset, batch_size=200, shuffle=False, num_workers=4) 69 | 70 | ytrain = np.asarray(trainset.targets).reshape(-1) 71 | ytest = np.asarray(testset.targets).reshape(-1) 72 | 73 | features_train = [] 74 | with torch.no_grad(): 75 | for batch_idx, (inputs, targets) in enumerate(trainloader): 76 | f = model(inputs.cuda()).cpu().numpy() 77 | features_train.append(f) 78 | 79 | features_train = np.concatenate(features_train, axis=0) 80 | print(features_train.shape) 81 | 82 | features_test = [] 83 | with torch.no_grad(): 84 | for batch_idx, (inputs, targets) in enumerate(testloader): 85 | f = model(inputs.cuda()).cpu().numpy() 86 | features_test.append(f) 87 | 88 | features_test = np.concatenate(features_test, axis=0) 89 | print(features_test.shape) 90 | 91 | os.makedirs("transfer/features/", exist_ok=True) 92 | np.save("transfer/features/cifar100_resnext_train.npy", features_train) 93 | np.save("transfer/features/cifar100_resnext_test.npy", features_test) 94 | np.save("transfer/features/cifar100_resnext_mean.npy", mean_cifar100) 95 | np.save("transfer/features/cifar100_resnext_var.npy", var_cifar100) 96 | 97 | mean = np.mean(features_train, axis=0) 98 | var = np.var(features_train, axis=0) 99 | 100 | features_train_norm = (features_train - mean) / np.sqrt(var + 1e-5) 101 | features_test_norm = (features_test - mean) / np.sqrt(var + 1e-5) 102 | 103 | features_train_norm2 = (features_train - mean_cifar100) / np.sqrt(var_cifar100 + 1e-5) 104 | features_test_norm2 = (features_test - mean_cifar100) / np.sqrt(var_cifar100 + 1e-5) 105 | 106 | for C in [0.01, 0.1, 1.0, 10.0, 100.0]: 107 | clf = LogisticRegression(random_state=0, max_iter=1000, C=C).fit(features_train, ytrain) 108 | print(C, clf.score(features_train, ytrain), clf.score(features_test, ytest)) 109 | 110 | clf = LogisticRegression(random_state=0, max_iter=1000, C=C).fit(features_train_norm, ytrain) 111 | print(C, clf.score(features_train_norm, ytrain), clf.score(features_test_norm, ytest)) 112 | 113 | clf = LogisticRegression(random_state=0, max_iter=1000, C=C).fit(features_train_norm2, ytrain) 114 | print(C, clf.score(features_train_norm2, ytrain), clf.score(features_test_norm2, ytest)) 115 | 116 | -------------------------------------------------------------------------------- /transfer/extract_simclr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from https://colab.research.google.com/github/google-research/simclr/blob/master/colabs/finetuning.ipynb 3 | """ 4 | 5 | import numpy as np 6 | 7 | import tensorflow.compat.v1 as tf 8 | tf.disable_eager_execution() 9 | import tensorflow_hub as hub 10 | from sklearn.linear_model import LogisticRegression 11 | import os 12 | 13 | 14 | CROP_PROPORTION = 0.875 # Standard for ImageNet. 
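# For the 32x32 CIFAR-10 inputs used below and a 224x224 target resolution,
# center_crop() keeps a 28x28 window (0.875 * 32 = 28) and bicubic-resizes it
# to 224x224 before the image is fed to the SimCLR encoder.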
15 | 16 | 17 | def random_apply(func, p, x): 18 | """Randomly apply function func to x with probability p.""" 19 | return tf.cond( 20 | tf.less(tf.random_uniform([], minval=0, maxval=1, dtype=tf.float32), 21 | tf.cast(p, tf.float32)), 22 | lambda: func(x), 23 | lambda: x) 24 | 25 | 26 | def _compute_crop_shape( 27 | image_height, image_width, aspect_ratio, crop_proportion): 28 | """Compute aspect ratio-preserving shape for central crop. 29 | The resulting shape retains `crop_proportion` along one side and a proportion 30 | less than or equal to `crop_proportion` along the other side. 31 | Args: 32 | image_height: Height of image to be cropped. 33 | image_width: Width of image to be cropped. 34 | aspect_ratio: Desired aspect ratio (width / height) of output. 35 | crop_proportion: Proportion of image to retain along the less-cropped side. 36 | Returns: 37 | crop_height: Height of image after cropping. 38 | crop_width: Width of image after cropping. 39 | """ 40 | image_width_float = tf.cast(image_width, tf.float32) 41 | image_height_float = tf.cast(image_height, tf.float32) 42 | 43 | def _requested_aspect_ratio_wider_than_image(): 44 | crop_height = tf.cast(tf.rint( 45 | crop_proportion / aspect_ratio * image_width_float), tf.int32) 46 | crop_width = tf.cast(tf.rint( 47 | crop_proportion * image_width_float), tf.int32) 48 | return crop_height, crop_width 49 | 50 | def _image_wider_than_requested_aspect_ratio(): 51 | crop_height = tf.cast( 52 | tf.rint(crop_proportion * image_height_float), tf.int32) 53 | crop_width = tf.cast(tf.rint( 54 | crop_proportion * aspect_ratio * 55 | image_height_float), tf.int32) 56 | return crop_height, crop_width 57 | 58 | return tf.cond( 59 | aspect_ratio > image_width_float / image_height_float, 60 | _requested_aspect_ratio_wider_than_image, 61 | _image_wider_than_requested_aspect_ratio) 62 | 63 | 64 | def center_crop(image, height, width, crop_proportion): 65 | """Crops to center of image and rescales to desired size. 66 | Args: 67 | image: Image Tensor to crop. 68 | height: Height of image to be cropped. 69 | width: Width of image to be cropped. 70 | crop_proportion: Proportion of image to retain along the less-cropped side. 71 | Returns: 72 | A `height` x `width` x channels Tensor holding a central crop of `image`. 73 | """ 74 | shape = tf.shape(image) 75 | image_height = shape[0] 76 | image_width = shape[1] 77 | crop_height, crop_width = _compute_crop_shape( 78 | image_height, image_width, height / width, crop_proportion) 79 | offset_height = ((image_height - crop_height) + 1) // 2 80 | offset_width = ((image_width - crop_width) + 1) // 2 81 | image = tf.image.crop_to_bounding_box( 82 | image, offset_height, offset_width, crop_height, crop_width) 83 | 84 | image = tf.image.resize_bicubic([image], [height, width])[0] 85 | 86 | return image 87 | 88 | 89 | def preprocess_for_eval(image, height, width, crop=True): 90 | """Preprocesses the given image for evaluation. 91 | Args: 92 | image: `Tensor` representing an image of arbitrary size. 93 | height: Height of output image. 94 | width: Width of output image. 95 | crop: Whether or not to (center) crop the test images. 96 | Returns: 97 | A preprocessed image `Tensor`. 98 | """ 99 | if crop: 100 | image = center_crop(image, height, width, crop_proportion=CROP_PROPORTION) 101 | image = tf.reshape(image, [height, width, 3]) 102 | image = tf.clip_by_value(image, 0., 1.) 
103 | return image 104 | 105 | 106 | def preprocess_image(image, height, width, is_training=False, 107 | color_distort=True, test_crop=True): 108 | """Preprocesses the given image. 109 | Args: 110 | image: `Tensor` representing an image of arbitrary size. 111 | height: Height of output image. 112 | width: Width of output image. 113 | is_training: `bool` for whether the preprocessing is for training. 114 | color_distort: whether to apply the color distortion. 115 | test_crop: whether or not to extract a central crop of the images 116 | (as for standard ImageNet evaluation) during the evaluation. 117 | Returns: 118 | A preprocessed image `Tensor` of range [0, 1]. 119 | """ 120 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 121 | #if is_training: 122 | # return preprocess_for_train(image, height, width, color_distort) 123 | #else: 124 | return preprocess_for_eval(image, height, width, test_crop) 125 | 126 | 127 | (xtrain, ytrain), (xtest, ytest) = tf.keras.datasets.cifar10.load_data() 128 | xtrain = xtrain.astype(np.float32) / 255.0 129 | xtest = xtest.astype(np.float32) / 255.0 130 | ytrain = ytrain.reshape(-1) 131 | ytest = ytest.reshape(-1) 132 | 133 | 134 | def _preprocess(x): 135 | x = preprocess_image(x, 224, 224, is_training=False, color_distort=False) 136 | return x 137 | 138 | 139 | batch_size = 100 140 | x = tf.placeholder(shape=(batch_size, 32, 32, 3), dtype=tf.float32) 141 | x_preproc = tf.map_fn(_preprocess, x) 142 | print(x_preproc.get_shape().as_list()) 143 | 144 | hub_path = 'gs://simclr-checkpoints/simclrv2/pretrained/r50_2x_sk1/hub/' 145 | module = hub.Module(hub_path, trainable=False) 146 | features = module(inputs=x_preproc, signature='default') 147 | print(features.get_shape().as_list()) 148 | 149 | sess = tf.Session() 150 | sess.run(tf.global_variables_initializer()) 151 | print("model loaded!") 152 | 153 | features_train = [] 154 | for i in range(len(xtrain)//batch_size): 155 | x_batch = xtrain[i*batch_size:(i+1)*batch_size] 156 | f = sess.run(features, feed_dict={x: x_batch}) 157 | features_train.append(f) 158 | 159 | features_train = np.concatenate(features_train, axis=0) 160 | print(features_train.shape) 161 | 162 | features_test = [] 163 | for i in range(len(xtest)//batch_size): 164 | x_batch = xtest[i*batch_size:(i+1)*batch_size] 165 | f = sess.run(features, feed_dict={x: x_batch}) 166 | features_test.append(f) 167 | 168 | features_test = np.concatenate(features_test, axis=0) 169 | print(features_test.shape) 170 | 171 | os.makedirs("transfer/features/", exist_ok=True) 172 | np.save("transfer/features/simclr_r50_2x_sk1_train.npy", features_train) 173 | np.save("transfer/features/simclr_r50_2x_sk1_test.npy", features_test) 174 | 175 | mean = np.mean(features_train, axis=0) 176 | var = np.var(features_train, axis=0) 177 | 178 | features_train_norm = (features_train - mean) / np.sqrt(var + 1e-5) 179 | features_test_norm = (features_test - mean) / np.sqrt(var + 1e-5) 180 | 181 | for C in [0.001, 0.01, 0.1, 1.0, 10.0, 100.0]: 182 | clf = LogisticRegression(random_state=0, max_iter=1000, C=C).fit(features_train, ytrain) 183 | print(C, clf.score(features_train, ytrain), clf.score(features_test, ytest)) 184 | 185 | clf = LogisticRegression(random_state=0, max_iter=1000, C=C).fit(features_train_norm, ytrain) 186 | print(C, clf.score(features_train_norm, ytrain), clf.score(features_test_norm, ytest)) 187 | -------------------------------------------------------------------------------- /transfer/resnext.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | """ 3 | Creates a ResNeXt Model as defined in: 4 | Xie, S., Girshick, R., Dollar, P., Tu, Z., & He, K. (2016). 5 | Aggregated residual transformations for deep neural networks. 6 | arXiv preprint arXiv:1611.05431. 7 | import from https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py 8 | """ 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn import init 12 | 13 | __all__ = ['resnext'] 14 | 15 | 16 | class ResNeXtBottleneck(nn.Module): 17 | """ 18 | RexNeXt bottleneck type C (https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua) 19 | """ 20 | def __init__(self, in_channels, out_channels, stride, cardinality, widen_factor): 21 | """ Constructor 22 | Args: 23 | in_channels: input channel dimensionality 24 | out_channels: output channel dimensionality 25 | stride: conv stride. Replaces pooling layer. 26 | cardinality: num of convolution groups. 27 | widen_factor: factor to reduce the input dimensionality before convolution. 28 | """ 29 | super(ResNeXtBottleneck, self).__init__() 30 | D = cardinality * out_channels // widen_factor 31 | self.conv_reduce = nn.Conv2d(in_channels, D, kernel_size=1, stride=1, padding=0, bias=False) 32 | self.bn_reduce = nn.BatchNorm2d(D) 33 | self.conv_conv = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 34 | self.bn = nn.BatchNorm2d(D) 35 | self.conv_expand = nn.Conv2d(D, out_channels, kernel_size=1, stride=1, padding=0, bias=False) 36 | self.bn_expand = nn.BatchNorm2d(out_channels) 37 | 38 | self.shortcut = nn.Sequential() 39 | if in_channels != out_channels: 40 | self.shortcut.add_module('shortcut_conv', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)) 41 | self.shortcut.add_module('shortcut_bn', nn.BatchNorm2d(out_channels)) 42 | 43 | def forward(self, x): 44 | bottleneck = self.conv_reduce(x) 45 | bottleneck = F.relu(self.bn_reduce(bottleneck), inplace=True) 46 | bottleneck = self.conv_conv(bottleneck) 47 | bottleneck = F.relu(self.bn(bottleneck), inplace=True) 48 | bottleneck = self.conv_expand(bottleneck) 49 | bottleneck = self.bn_expand(bottleneck) 50 | residual = self.shortcut(x) 51 | return F.relu(residual + bottleneck, inplace=True) 52 | 53 | def train(self, mode=True, freeze_DP=True): 54 | super(ResNeXtBottleneck, self).train(mode) 55 | 56 | if freeze_DP: 57 | self.conv_conv.eval() 58 | self.conv_conv.weight.requires_grad = False 59 | for m in self.modules(): 60 | if isinstance(m, nn.BatchNorm2d): 61 | m.weight.requires_grad = False 62 | m.bias.requires_grad = False 63 | 64 | 65 | class CifarResNeXt(nn.Module): 66 | """ 67 | ResNext optimized for the Cifar dataset, as specified in 68 | https://arxiv.org/pdf/1611.05431.pdf 69 | """ 70 | def __init__(self, cardinality, depth, num_classes, widen_factor=4, dropRate=0): 71 | """ Constructor 72 | Args: 73 | cardinality: number of convolution groups. 74 | depth: number of layers. 
75 | num_classes: number of classes 76 | widen_factor: factor to adjust the channel dimensionality 77 | """ 78 | super(CifarResNeXt, self).__init__() 79 | self.cardinality = cardinality 80 | self.depth = depth 81 | self.block_depth = (self.depth - 2) // 9 82 | self.widen_factor = widen_factor 83 | self.num_classes = num_classes 84 | self.output_size = 64 85 | self.stages = [64, 64 * self.widen_factor, 128 * self.widen_factor, 256 * self.widen_factor] 86 | 87 | self.conv_1_3x3 = nn.Conv2d(3, 64, 3, 1, 1, bias=False) 88 | self.bn_1 = nn.BatchNorm2d(64) 89 | self.stage_1 = self.block('stage_1', self.stages[0], self.stages[1], 1) 90 | self.stage_2 = self.block('stage_2', self.stages[1], self.stages[2], 2) 91 | self.stage_3 = self.block('stage_3', self.stages[2], self.stages[3], 2) 92 | self.classifier = nn.Linear(1024, num_classes) 93 | init.kaiming_normal(self.classifier.weight) 94 | 95 | for key in self.state_dict(): 96 | if key.split('.')[-1] == 'weight': 97 | if 'conv' in key: 98 | init.kaiming_normal(self.state_dict()[key], mode='fan_out') 99 | if 'bn' in key: 100 | self.state_dict()[key][...] = 1 101 | elif key.split('.')[-1] == 'bias': 102 | self.state_dict()[key][...] = 0 103 | 104 | def block(self, name, in_channels, out_channels, pool_stride=2): 105 | """ Stack n bottleneck modules where n is inferred from the depth of the network. 106 | Args: 107 | name: string name of the current block. 108 | in_channels: number of input channels 109 | out_channels: number of output channels 110 | pool_stride: factor to reduce the spatial dimensionality in the first bottleneck of the block. 111 | Returns: a Module consisting of n sequential bottlenecks. 112 | """ 113 | block = nn.Sequential() 114 | for bottleneck in range(self.block_depth): 115 | name_ = '%s_bottleneck_%d' % (name, bottleneck) 116 | if bottleneck == 0: 117 | block.add_module(name_, ResNeXtBottleneck(in_channels, out_channels, pool_stride, self.cardinality, 118 | self.widen_factor)) 119 | else: 120 | block.add_module(name_, 121 | ResNeXtBottleneck(out_channels, out_channels, 1, self.cardinality, self.widen_factor)) 122 | return block 123 | 124 | def forward(self, x): 125 | x = self.conv_1_3x3(x) 126 | x = F.relu(self.bn_1(x), inplace=True) 127 | x = self.stage_1(x) 128 | x = self.stage_2(x) 129 | x = self.stage_3(x) 130 | x = F.avg_pool2d(x, 8, 1) 131 | x = x.view(-1, 1024) 132 | return self.classifier(x) 133 | 134 | def train(self, mode=True, freeze_DP=True): 135 | super(CifarResNeXt, self).train(mode) 136 | if freeze_DP: 137 | for m in self.modules(): 138 | if isinstance(m, ResNeXtBottleneck): 139 | m.train(mode, freeze_DP) 140 | 141 | if isinstance(m, nn.BatchNorm2d): 142 | m.weight.requires_grad = False 143 | m.bias.requires_grad = False 144 | 145 | 146 | def resnext(**kwargs): 147 | """Constructs a ResNeXt. 
148 | """ 149 | model = CifarResNeXt(**kwargs) 150 | return model 151 | -------------------------------------------------------------------------------- /transfer/transfer_cifar.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | from opacus import PrivacyEngine 7 | 8 | from models import StandardizeLayer 9 | from train_utils import get_device, train, test 10 | from data import get_data 11 | from dp_utils import ORDERS, get_privacy_spent, get_renyi_divergence 12 | from log import Logger 13 | 14 | 15 | def main(feature_path=None, batch_size=2048, mini_batch_size=256, 16 | lr=1, optim="SGD", momentum=0.9, nesterov=False, noise_multiplier=1, 17 | max_grad_norm=0.1, max_epsilon=None, epochs=100, logdir=None): 18 | 19 | logger = Logger(logdir) 20 | 21 | device = get_device() 22 | 23 | # get pre-computed features 24 | x_train = np.load(f"{feature_path}_train.npy") 25 | x_test = np.load(f"{feature_path}_test.npy") 26 | 27 | train_data, test_data = get_data("cifar10", augment=False) 28 | y_train = np.asarray(train_data.targets) 29 | y_test = np.asarray(test_data.targets) 30 | 31 | trainset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)) 32 | testset = torch.utils.data.TensorDataset(torch.from_numpy(x_test), torch.from_numpy(y_test)) 33 | 34 | bs = batch_size 35 | assert bs % mini_batch_size == 0 36 | n_acc_steps = bs // mini_batch_size 37 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=mini_batch_size, shuffle=True, num_workers=1, pin_memory=True, drop_last=True) 38 | test_loader = torch.utils.data.DataLoader(testset, batch_size=mini_batch_size, shuffle=False, num_workers=1, pin_memory=True) 39 | 40 | n_features = x_train.shape[-1] 41 | try: 42 | mean = np.load(f"{feature_path}_mean.npy") 43 | var = np.load(f"{feature_path}_var.npy") 44 | except FileNotFoundError: 45 | mean = np.zeros(n_features, dtype=np.float32) 46 | var = np.ones(n_features, dtype=np.float32) 47 | 48 | bn_stats = (torch.from_numpy(mean).to(device), torch.from_numpy(var).to(device)) 49 | 50 | model = nn.Sequential(StandardizeLayer(bn_stats), nn.Linear(n_features, 10)).to(device) 51 | 52 | if optim == "SGD": 53 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, 54 | momentum=momentum, 55 | nesterov=nesterov) 56 | else: 57 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 58 | 59 | privacy_engine = PrivacyEngine( 60 | model, 61 | sample_rate=bs / len(train_data), 62 | alphas=ORDERS, 63 | noise_multiplier=noise_multiplier, 64 | max_grad_norm=max_grad_norm, 65 | ) 66 | privacy_engine.attach(optimizer) 67 | 68 | for epoch in range(0, epochs): 69 | print(f"\nEpoch: {epoch}") 70 | 71 | train_loss, train_acc = train(model, train_loader, optimizer, n_acc_steps=n_acc_steps) 72 | test_loss, test_acc = test(model, test_loader) 73 | 74 | if noise_multiplier > 0: 75 | rdp_sgd = get_renyi_divergence( 76 | privacy_engine.sample_rate, privacy_engine.noise_multiplier 77 | ) * privacy_engine.steps 78 | epsilon, _ = get_privacy_spent(rdp_sgd) 79 | print(f"ε = {epsilon:.3f}") 80 | 81 | if max_epsilon is not None and epsilon >= max_epsilon: 82 | return 83 | else: 84 | epsilon = None 85 | 86 | logger.log_epoch(epoch, train_loss, train_acc, test_loss, test_acc, epsilon) 87 | 88 | 89 | if __name__ == '__main__': 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--batch_size', type=int, default=256) 92 | parser.add_argument('--lr', type=float, 
default=0.01) 93 | parser.add_argument('--optim', type=str, default="SGD", choices=["SGD", "Adam"]) 94 | parser.add_argument('--momentum', type=float, default=0.9) 95 | parser.add_argument('--nesterov', action="store_true") 96 | parser.add_argument('--noise_multiplier', type=float, default=1) 97 | parser.add_argument('--max_grad_norm', type=float, default=0.1) 98 | parser.add_argument('--epochs', type=int, default=50) 99 | parser.add_argument('--feature_path', default=None) 100 | parser.add_argument('--max_epsilon', type=float, default=None) 101 | parser.add_argument('--logdir', default=None) 102 | args = parser.parse_args() 103 | main(**vars(args)) 104 | --------------------------------------------------------------------------------
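As a rough illustration of how the transfer pipeline fits together, the sketch below calls transfer_cifar.main() on the SimCLR features written out by transfer/extract_simclr.py (prefix transfer/features/simclr_r50_2x_sk1). It assumes it is run from the repository root; the hyperparameter values are placeholders chosen for illustration, not tuned or reported settings.

# Illustrative sketch: DP-SGD training of a linear classifier on pre-extracted
# SimCLR features. The feature_path matches the prefix saved by
# transfer/extract_simclr.py; lr, batch_size, noise_multiplier and max_epsilon
# are placeholder values, not the settings used in the paper.
from transfer.transfer_cifar import main

main(feature_path="transfer/features/simclr_r50_2x_sk1",
     batch_size=1024, mini_batch_size=256,
     lr=1.0, optim="SGD", momentum=0.9,
     noise_multiplier=2.0, max_grad_norm=0.1,
     epochs=40, max_epsilon=3.0,
     logdir="logs/transfer/simclr")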