├── LICENSE ├── README.md ├── attacker_threshold.py ├── attackers.py ├── base_model.py ├── cifar10_models ├── __init__.py ├── densenet.py ├── googlenet.py ├── inception.py ├── mobilenetv2.py ├── resnet.py ├── resnet_orig.py └── vgg.py ├── config ├── chmnist_dense.json ├── chmnist_resnet18.json ├── chmnist_vgg16.json ├── cifar100_dense.json ├── cifar100_resnet18.json ├── cifar100_vgg16.json ├── cifar10_dense.json ├── cifar10_resnet18.json ├── cifar10_vgg16.json ├── location.json ├── purchase.json ├── svhn_dense.json ├── svhn_resnet18.json ├── svhn_vgg16.json └── texas100.json ├── datasets.py ├── divergence_plot.ipynb ├── figs └── attack_pipeline.png ├── mia.py ├── models.py ├── pretrain.py ├── prune.py ├── prune_dp.py ├── pruner.py ├── pyvacy ├── __init__.py ├── analysis │ ├── __init__.py │ ├── epsilon_calculation.py │ ├── moments_accountant.py │ ├── rdp_accountant.py │ └── subsampled.py ├── optim │ ├── __init__.py │ └── dp_optimizer.py └── sampling │ ├── __init__.py │ └── batch_samplers.py ├── requirements.txt ├── transformer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Machine-Learning-Security-Lab (Michigan Tech) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [comment]: <> (# mia_prune) 2 | # Membership Inference Attacks and Defenses in Neural Network Pruning 3 | 4 | This repository accompanies the paper [Membership Inference Attacks and Defenses in Neural Network Pruning](https://www.usenix.org/conference/usenixsecurity22/presentation/yuan-xiaoyong), 5 | accepted by USENIX Security 2022. The extended version can be found on [arXiv](https://arxiv.org/abs/2202.03335). 6 | The repository contains the main code for the membership inference attacks and defenses in neural network pruning. 7 | The code is tested on Python 3.8, PyTorch 1.8.1, and Ubuntu 18.04. 8 | GPUs are needed to accelerate neural network training and membership inference attacks. 9 | 10 | ![Attack Pipeline](figs/attack_pipeline.png) 11 | 12 | # Background 13 | Neural network pruning is an essential technique for reducing the computation and memory requirements of 14 | deep neural networks deployed on resource-constrained devices. 15 | We investigated membership inference attacks (MIAs) and their countermeasures in neural network pruning. 16 | We proposed a membership inference attack, namely the self-attention membership inference attack (SAMIA), 17 | targeted at pruned neural networks, as well as a pair-based posterior balancing (PPB) defense method.
18 | 19 | # Installation 20 | Get the repository: 21 | ``` 22 | git clone https://github.com/Machine-Learning-Security-Lab/mia_prune 23 | cd mia_prune 24 | ``` 25 | Install the Python packages. 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | Create a folder for storing datasets. The data folder location can be changed in `datasets.py`. 30 | ``` 31 | mkdir -p data/datasets 32 | ``` 33 | Create a folder for storing the trained models. 34 | ``` 35 | mkdir results 36 | ``` 37 | 38 | # Usage 39 | 40 | ## Attacks 41 | 1. Train an original neural network: 42 | ``` 43 | python pretrain.py [GPU-ID] [config_path] 44 | ``` 45 | 2. Prune and fine-tune the model: 46 | ``` 47 | python prune.py [GPU-ID] [config_path] --pruner_name [pruner_name] --prune_sparsity [prune_sparsity] 48 | ``` 49 | 3. Conduct membership inference attacks on the pruned model: 50 | ``` 51 | python mia.py [GPU-ID] [config_path] --pruner_name [pruner_name] --prune_sparsity [prune_sparsity] --attacks [attacks] 52 | ``` 53 | 4. Conduct membership inference attacks on the original model: 54 | ``` 55 | python mia.py [GPU-ID] [config_path] --attacks [attacks] --original 56 | ``` 57 | 58 | ## Defenses 59 | 1. Starting from an original model, prune and fine-tune it with a defense method and its arguments: 60 | ``` 61 | python prune.py [GPU-ID] [config_path] --pruner_name [pruner_name] --prune_sparsity [prune_sparsity] \ 62 | --defend [defend] --defend_arg [defend_arg] 63 | ``` 64 | 2. Conduct membership inference attacks on the pruned model with defense: 65 | ``` 66 | python mia.py [GPU-ID] [config_path] --pruner_name [pruner_name] --prune_sparsity [prune_sparsity] --attacks [attacks] \ 67 | --defend [defend] --defend_arg [defend_arg] 68 | ``` 69 | 3. Conduct membership inference attacks on the pruned model with defense when the attacker knows the defense (adaptive attack):
70 | ``` 71 | python mia.py [GPU-ID] [config_path] --pruner_name [pruner_name] --prune_sparsity [prune_sparsity] --attacks [attacks] \ 72 | --defend [defend] --defend_arg [defend_arg] --adaptive 73 | ``` 74 | 75 | ## Argument options 76 | - `config_path` is the path to a file in the `config` folder that specifies the dataset and neural network architecture. 77 | - `pruner_name` can be `l1unstructure` (default), `l1structure`, `l2structure`, or `slim`. 78 | - `prune_sparsity` can be any float value in (0, 1); the default is 0.7. 79 | - `attacks` can be `samia` (default), `threshold`, `nn`, `nn_top3`, or `nn_cls`. Multiple attacks can be given as a comma-separated list, 80 | e.g., `--attacks samia,nn,nn_top3`. 81 | The `threshold` attack (modified from https://github.com/inspire-group/membership-inference-evaluation) 82 | performs several threshold-based attacks, including: 83 | - Ground-truth class confidence-based threshold attack (Conf). 84 | - Cross-entropy-based threshold attack (Xent). 85 | - Modified-entropy-based threshold attack (Mentr). 86 | - Top-1 confidence-based threshold attack (Top1-conf). 87 | - `defend` can be `""` (basic defense, default), `ppb` (PPB defense), or `adv` (Adversarial Regularization). 88 | To run the DP defense, please use `prune_dp.py`, where we use [pyvacy](https://github.com/ChrisWaites/pyvacy) to run 89 | DP-SGD during fine-tuning. 90 | 91 | # Examples 92 | Train a ResNet18 model on CIFAR10 using GPU 0. 93 | ``` 94 | python pretrain.py 0 config/cifar10_resnet18.json 95 | ``` 96 | Prune the model using `l1unstructure` pruning with a sparsity level of 70% (i.e., remove 70% of the parameters). 97 | ``` 98 | python prune.py 0 config/cifar10_resnet18.json --pruner_name l1unstructure --prune_sparsity 0.7 99 | ``` 100 | Attack the pruned model using SAMIA. 101 | ``` 102 | python mia.py 0 config/cifar10_resnet18.json --attacks samia 103 | ``` 104 | Attack the original model using SAMIA.
105 | ``` 106 | python mia.py 0 config/cifar10_resnet18.json --attacks samia --original 107 | ``` 108 | Prune the original model with the PPB defense. 109 | ``` 110 | python prune.py 0 config/cifar10_resnet18.json --defend ppb --defend_arg 4 111 | ``` 112 | Attack the pruned model with defense using SAMIA and the threshold attacks. 113 | ``` 114 | python mia.py 0 config/cifar10_resnet18.json --attacks samia,threshold --defend ppb --defend_arg 4 115 | ``` 116 | Attack the pruned model with defense using SAMIA and the threshold attacks when the attacker knows the defense, i.e., an adaptive attack. 117 | ``` 118 | python mia.py 0 config/cifar10_resnet18.json --attacks samia,threshold --defend ppb --defend_arg 4 --adaptive 119 | ``` 120 | Launch multiple attacks. 121 | ``` 122 | python mia.py 0 config/cifar10_resnet18.json --attacks samia,nn,nn_top3 123 | ``` 124 | 125 | # Structure 126 | - [pretrain.py](pretrain.py) is used to train the original neural network. 127 | - [prune.py](prune.py) is used to prune and fine-tune the network with and without defenses. 128 | - [mia.py](mia.py) is used to conduct attacks on the pruned/original network with and without defenses. 129 | - The [config](config) folder contains all the configuration files for various datasets and neural network architectures. 130 | - [transformer.py](transformer.py) contains the self-attention model used in the SAMIA attack. 131 | - [base_model.py](base_model.py) includes the code for training, testing, and defense.
132 | 133 | 134 | # Citation 135 | ``` 136 | @inproceedings{yuan2022membership, 137 | title = {Membership Inference Attacks and Defenses in Neural Network Pruning}, 138 | booktitle = {31st USENIX Security Symposium (USENIX Security 22)}, 139 | author={Yuan, Xiaoyong and Zhang, Lan}, 140 | year={2022} 141 | } 142 | ``` 143 | -------------------------------------------------------------------------------- /attacker_threshold.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is modified from https://github.com/inspire-group/membership-inference-evaluation 3 | """ 4 | import numpy as np 5 | 6 | class ThresholdAttacker: 7 | def __init__(self, shadow_train_performance, shadow_test_performance, target_train_performance, 8 | target_test_performance, num_classes): 9 | self.num_classes = num_classes 10 | 11 | self.s_tr_outputs, self.s_tr_labels = shadow_train_performance 12 | self.s_te_outputs, self.s_te_labels = shadow_test_performance 13 | self.t_tr_outputs, self.t_tr_labels = target_train_performance 14 | self.t_te_outputs, self.t_te_labels = target_test_performance 15 | 16 | self.s_tr_corr = (np.argmax(self.s_tr_outputs, axis=1) == self.s_tr_labels).astype(int) 17 | self.s_te_corr = (np.argmax(self.s_te_outputs, axis=1) == self.s_te_labels).astype(int) 18 | self.t_tr_corr = (np.argmax(self.t_tr_outputs, axis=1) == self.t_tr_labels).astype(int) 19 | self.t_te_corr = (np.argmax(self.t_te_outputs, axis=1) == self.t_te_labels).astype(int) 20 | 21 | self.s_tr_conf = np.array([self.s_tr_outputs[i, self.s_tr_labels[i]] for i in range(len(self.s_tr_labels))]) 22 | self.s_te_conf = np.array([self.s_te_outputs[i, self.s_te_labels[i]] for i in range(len(self.s_te_labels))]) 23 | self.t_tr_conf = np.array([self.t_tr_outputs[i, self.t_tr_labels[i]] for i in range(len(self.t_tr_labels))]) 24 | self.t_te_conf = np.array([self.t_te_outputs[i, self.t_te_labels[i]] for i in range(len(self.t_te_labels))]) 25 | 26 | self.s_tr_entr = 
self._entr_comp(self.s_tr_outputs) 27 | self.s_te_entr = self._entr_comp(self.s_te_outputs) 28 | self.t_tr_entr = self._entr_comp(self.t_tr_outputs) 29 | self.t_te_entr = self._entr_comp(self.t_te_outputs) 30 | 31 | self.s_tr_m_entr = self._m_entr_comp(self.s_tr_outputs, self.s_tr_labels) 32 | self.s_te_m_entr = self._m_entr_comp(self.s_te_outputs, self.s_te_labels) 33 | self.t_tr_m_entr = self._m_entr_comp(self.t_tr_outputs, self.t_tr_labels) 34 | self.t_te_m_entr = self._m_entr_comp(self.t_te_outputs, self.t_te_labels) 35 | 36 | def _log_value(self, probs, small_value=1e-30): 37 | return -np.log(np.maximum(probs, small_value)) 38 | 39 | def _entr_comp(self, probs): 40 | return np.sum(np.multiply(probs, self._log_value(probs)), axis=1) 41 | 42 | def _m_entr_comp(self, probs, true_labels): 43 | log_probs = self._log_value(probs) 44 | reverse_probs = 1 - probs 45 | log_reverse_probs = self._log_value(reverse_probs) 46 | modified_probs = np.copy(probs) 47 | modified_probs[range(true_labels.size), true_labels] = reverse_probs[range(true_labels.size), true_labels] 48 | modified_log_probs = np.copy(log_reverse_probs) 49 | modified_log_probs[range(true_labels.size), true_labels] = log_probs[range(true_labels.size), true_labels] 50 | return np.sum(np.multiply(modified_probs, modified_log_probs), axis=1) 51 | 52 | def _thre_setting(self, tr_values, te_values): 53 | value_list = np.concatenate((tr_values, te_values)) 54 | thre, max_acc = 0, 0 55 | for value in value_list: 56 | tr_ratio = np.sum(tr_values >= value) / (len(tr_values) + 0.0) 57 | te_ratio = np.sum(te_values < value) / (len(te_values) + 0.0) 58 | acc = 0.5 * (tr_ratio + te_ratio) 59 | if acc > max_acc: 60 | thre, max_acc = value, acc 61 | return thre 62 | 63 | def _mem_inf_via_corr(self): 64 | # perform membership inference attack based on whether the input is correctly classified or not 65 | t_tr_acc = np.sum(self.t_tr_corr) / (len(self.t_tr_corr) + 0.0) 66 | t_te_acc = np.sum(self.t_te_corr) / 
(len(self.t_te_corr) + 0.0) 67 | mem_inf_acc = 0.5 * (t_tr_acc + 1 - t_te_acc) 68 | print( 69 | 'For membership inference attack via correctness, the attack acc is {acc1:.3f}, with train acc {acc2:.3f} ' 70 | 'and test acc {acc3:.3f}'.format( 71 | acc1=mem_inf_acc, acc2=t_tr_acc, acc3=t_te_acc)) 72 | return mem_inf_acc 73 | 74 | def _mem_inf_thre(self, v_name, s_tr_values, s_te_values, t_tr_values, t_te_values): 75 | # perform membership inference attack by thresholding feature values: the feature can be prediction confidence, 76 | # (negative) prediction entropy, and (negative) modified entropy 77 | t_tr_mem, t_te_non_mem = 0, 0 78 | for num in range(self.num_classes): 79 | thre = self._thre_setting(s_tr_values[self.s_tr_labels == num], s_te_values[self.s_te_labels == num]) 80 | t_tr_mem_tmp = np.sum(t_tr_values[self.t_tr_labels == num] >= thre) 81 | t_tr_mem += t_tr_mem_tmp 82 | t_te_non_mem_tmp = np.sum(t_te_values[self.t_te_labels == num] < thre) 83 | t_te_non_mem += t_te_non_mem_tmp 84 | tmp_acc = 0.5 * (t_tr_mem_tmp / (len(t_tr_values[self.t_tr_labels == num]) + 0.0) + 85 | t_te_non_mem_tmp / (len(t_te_values[self.t_te_labels == num]) + 0.0)) 86 | mem_inf_acc = 0.5 * (t_tr_mem / (len(self.t_tr_labels) + 0.0) + t_te_non_mem / (len(self.t_te_labels) + 0.0)) 87 | return mem_inf_acc 88 | 89 | def _mem_inf_thre_non_cls(self, v_name, s_tr_values, s_te_values, t_tr_values, t_te_values): 90 | # perform membership inference attack by thresholding feature values: the feature can be prediction confidence, 91 | # (negative) prediction entropy, and (negative) modified entropy 92 | # t_tr_mem, t_te_non_mem = 0, 0 93 | thre = self._thre_setting(s_tr_values, s_te_values) 94 | t_tr_mem = np.sum(t_tr_values >= thre) 95 | t_te_non_mem = np.sum(t_te_values < thre) 96 | mem_inf_acc = 0.5 * (t_tr_mem / (len(self.t_tr_labels) + 0.0) + t_te_non_mem / (len(self.t_te_labels) + 0.0)) 97 | return mem_inf_acc 98 | 99 | def _mem_inf_benchmarks(self): 100 | confidence = \ 101 | 
self._mem_inf_thre('confidence', self.s_tr_conf, self.s_te_conf, self.t_tr_conf, self.t_te_conf) 102 | entropy = \ 103 | self._mem_inf_thre('entropy', -self.s_tr_entr, -self.s_te_entr, -self.t_tr_entr, -self.t_te_entr) 104 | modentr = \ 105 | self._mem_inf_thre('modified entropy', -self.s_tr_m_entr, -self.s_te_m_entr, -self.t_tr_m_entr, 106 | -self.t_te_m_entr) 107 | return confidence, entropy, modentr 108 | 109 | def _mem_inf_benchmarks_non_cls(self): 110 | confidence = \ 111 | self._mem_inf_thre_non_cls('confidence', self.s_tr_conf, self.s_te_conf, self.t_tr_conf, self.t_te_conf) 112 | entropy = \ 113 | self._mem_inf_thre_non_cls('entropy', -self.s_tr_entr, -self.s_te_entr, -self.t_tr_entr, -self.t_te_entr) 114 | modentr = \ 115 | self._mem_inf_thre_non_cls('modified entropy', -self.s_tr_m_entr, -self.s_te_m_entr, -self.t_tr_m_entr, 116 | -self.t_te_m_entr) 117 | return confidence, entropy, modentr 118 | -------------------------------------------------------------------------------- /attackers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from attacker_threshold import ThresholdAttacker 4 | from base_model import BaseModel 5 | 6 | from torch.utils.data import DataLoader, TensorDataset 7 | from utils import seed_worker 8 | 9 | 10 | class MiaAttack: 11 | def __init__(self, victim_model, victim_pruned_model, victim_train_loader, victim_test_loader, 12 | shadow_model_list, shadow_pruned_model_list, shadow_train_loader_list, shadow_test_loader_list, 13 | num_cls=10, batch_size=128, device="cuda", 14 | lr=0.001, optimizer="sgd", epochs=100, weight_decay=5e-4, 15 | # lr=0.001, optimizer="adam", epochs=100, weight_decay=5e-4, 16 | attack_original=False 17 | ): 18 | self.victim_model = victim_model 19 | self.victim_pruned_model = victim_pruned_model 20 | self.victim_train_loader = victim_train_loader 21 | self.victim_test_loader = victim_test_loader 22 | self.shadow_model_list 
= shadow_model_list 23 | self.shadow_pruned_model_list = shadow_pruned_model_list 24 | self.shadow_train_loader_list = shadow_train_loader_list 25 | self.shadow_test_loader_list = shadow_test_loader_list 26 | self.num_cls = num_cls 27 | self.device = device 28 | self.lr = lr 29 | self.optimizer = optimizer 30 | self.weight_decay = weight_decay 31 | self.epochs = epochs 32 | self.batch_size = batch_size 33 | self.attack_original = attack_original 34 | self._prepare() 35 | 36 | def _prepare(self): 37 | attack_in_predicts_list, attack_in_targets_list, attack_in_sens_list = [], [], [] 38 | attack_out_predicts_list, attack_out_targets_list, attack_out_sens_list = [], [], [] 39 | for shadow_model, shadow_pruned_model, shadow_train_loader, shadow_test_loader in zip( 40 | self.shadow_model_list, self.shadow_pruned_model_list, self.shadow_train_loader_list, 41 | self.shadow_test_loader_list): 42 | 43 | if self.attack_original: 44 | attack_in_predicts, attack_in_targets, attack_in_sens = \ 45 | shadow_model.predict_target_sensitivity(shadow_train_loader) 46 | attack_out_predicts, attack_out_targets, attack_out_sens = \ 47 | shadow_model.predict_target_sensitivity(shadow_test_loader) 48 | else: 49 | attack_in_predicts, attack_in_targets, attack_in_sens = \ 50 | shadow_pruned_model.predict_target_sensitivity(shadow_train_loader) 51 | attack_out_predicts, attack_out_targets, attack_out_sens = \ 52 | shadow_pruned_model.predict_target_sensitivity(shadow_test_loader) 53 | 54 | attack_in_predicts_list.append(attack_in_predicts) 55 | attack_in_targets_list.append(attack_in_targets) 56 | attack_in_sens_list.append(attack_in_sens) 57 | attack_out_predicts_list.append(attack_out_predicts) 58 | attack_out_targets_list.append(attack_out_targets) 59 | attack_out_sens_list.append(attack_out_sens) 60 | 61 | self.attack_in_predicts = torch.cat(attack_in_predicts_list, dim=0) 62 | self.attack_in_targets = torch.cat(attack_in_targets_list, dim=0) 63 | self.attack_in_sens = 
torch.cat(attack_in_sens_list, dim=0) 64 | self.attack_out_predicts = torch.cat(attack_out_predicts_list, dim=0) 65 | self.attack_out_targets = torch.cat(attack_out_targets_list, dim=0) 66 | self.attack_out_sens = torch.cat(attack_out_sens_list, dim=0) 67 | 68 | if self.attack_original: 69 | self.victim_in_predicts, self.victim_in_targets, self.victim_in_sens = \ 70 | self.victim_model.predict_target_sensitivity(self.victim_train_loader) 71 | self.victim_out_predicts, self.victim_out_targets, self.victim_out_sens = \ 72 | self.victim_model.predict_target_sensitivity(self.victim_test_loader) 73 | else: 74 | self.victim_in_predicts, self.victim_in_targets, self.victim_in_sens = \ 75 | self.victim_pruned_model.predict_target_sensitivity(self.victim_train_loader) 76 | self.victim_out_predicts, self.victim_out_targets, self.victim_out_sens = \ 77 | self.victim_pruned_model.predict_target_sensitivity(self.victim_test_loader) 78 | 79 | def nn_attack(self, mia_type="nn_sens_cls", model_name="mia_fc"): 80 | attack_predicts = torch.cat([self.attack_in_predicts, self.attack_out_predicts], dim=0) 81 | attack_sens = torch.cat([self.attack_in_sens, self.attack_out_sens], dim=0) 82 | attack_targets = torch.cat([self.attack_in_targets, self.attack_out_targets], dim=0) 83 | attack_targets = F.one_hot(attack_targets, num_classes=self.num_cls).float() 84 | attack_labels = torch.cat([torch.ones(self.attack_in_targets.size(0)), 85 | torch.zeros(self.attack_out_targets.size(0))], dim=0).long() 86 | 87 | victim_predicts = torch.cat([self.victim_in_predicts, self.victim_out_predicts], dim=0) 88 | victim_sens = torch.cat([self.victim_in_sens, self.victim_out_sens], dim=0) 89 | victim_targets = torch.cat([self.victim_in_targets, self.victim_out_targets], dim=0) 90 | victim_targets = F.one_hot(victim_targets, num_classes=self.num_cls).float() 91 | victim_labels = torch.cat([torch.ones(self.victim_in_targets.size(0)), 92 | torch.zeros(self.victim_out_targets.size(0))], dim=0).long() 93 | 94 | 
if mia_type == "nn_cls": 95 | new_attack_data = torch.cat([attack_predicts, attack_targets], dim=1) 96 | new_victim_data = torch.cat([victim_predicts, victim_targets], dim=1) 97 | elif mia_type == "nn_top3": 98 | new_attack_data, _ = torch.topk(attack_predicts, k=3, dim=-1) 99 | new_victim_data, _ = torch.topk(victim_predicts, k=3, dim=-1) 100 | elif mia_type == "nn_sens_cls": 101 | new_attack_data = torch.cat([attack_predicts, attack_sens, attack_targets], dim=1) 102 | new_victim_data = torch.cat([victim_predicts, victim_sens, victim_targets], dim=1) 103 | else: 104 | new_attack_data = attack_predicts 105 | new_victim_data = victim_predicts 106 | 107 | attack_train_dataset = TensorDataset(new_attack_data, attack_labels) 108 | attack_train_dataloader = DataLoader( 109 | attack_train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True, 110 | worker_init_fn=seed_worker) 111 | attack_test_dataset = TensorDataset(new_victim_data, victim_labels) 112 | attack_test_dataloader = DataLoader( 113 | attack_test_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True, 114 | worker_init_fn=seed_worker) 115 | 116 | attack_model = BaseModel( 117 | model_name, device=self.device, num_cls=new_victim_data.size(1), optimizer=self.optimizer, lr=self.lr, 118 | weight_decay=self.weight_decay, epochs=self.epochs) 119 | 120 | for epoch in range(self.epochs): 121 | train_acc, train_loss = attack_model.train(attack_train_dataloader) 122 | test_acc, test_loss = attack_model.test(attack_test_dataloader) 123 | return test_acc 124 | 125 | def threshold_attack(self): 126 | victim_in_predicts = self.victim_in_predicts.numpy() 127 | victim_out_predicts = self.victim_out_predicts.numpy() 128 | 129 | attack_in_predicts = self.attack_in_predicts.numpy() 130 | attack_out_predicts = self.attack_out_predicts.numpy() 131 | attacker = ThresholdAttacker((attack_in_predicts, self.attack_in_targets.numpy()), 132 | (attack_out_predicts, 
self.attack_out_targets.numpy()), 133 | (victim_in_predicts, self.victim_in_targets.numpy()), 134 | (victim_out_predicts, self.victim_out_targets.numpy()), 135 | self.num_cls) 136 | confidence, entropy, modified_entropy = attacker._mem_inf_benchmarks() 137 | top1_conf, _, _ = attacker._mem_inf_benchmarks_non_cls() 138 | return confidence * 100., entropy * 100., modified_entropy * 100., \ 139 | top1_conf * 100. 140 | -------------------------------------------------------------------------------- /base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from utils import get_model, get_optimizer, weight_init 5 | 6 | 7 | class BaseModel: 8 | def __init__(self, model_type, device="cuda", save_folder="", num_cls=10, 9 | optimizer="", lr=0.01, weight_decay=0, input_dim=100, epochs=0, attack_model_type=''): 10 | self.model = get_model(model_type, num_cls, input_dim) 11 | self.model.to(device) 12 | self.model.apply(weight_init) 13 | self.device = device 14 | self.optimizer = get_optimizer(optimizer, self.model.parameters(), lr, weight_decay) 15 | if epochs == 0: 16 | self.scheduler = None 17 | else: 18 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR( 19 | self.optimizer, milestones=[epochs // 2, epochs * 3 // 4], gamma=0.1) 20 | self.criterion = nn.CrossEntropyLoss() 21 | self.save_pref = save_folder 22 | self.num_cls = num_cls 23 | 24 | if attack_model_type: 25 | self.attack_model = get_model(attack_model_type, num_cls*2, 2) 26 | self.attack_model.to(device) 27 | self.attack_model.apply(weight_init) 28 | self.attack_model_optim = get_optimizer("adam", self.attack_model.parameters(), lr=0.001, weight_decay=5e-4) 29 | 30 | def train(self, train_loader, log_pref=""): 31 | self.model.train() 32 | total_loss = 0 33 | correct = 0 34 | total = 0 35 | for inputs, targets in train_loader: 36 | inputs, targets = inputs.to(self.device), 
targets.to(self.device) 37 | self.optimizer.zero_grad() 38 | outputs = self.model(inputs) 39 | loss = self.criterion(outputs, targets) 40 | loss.backward() 41 | self.optimizer.step() 42 | total_loss += loss.item() * targets.size(0) 43 | total += targets.size(0) 44 | _, predicted = outputs.max(1) 45 | correct += predicted.eq(targets).sum().item() 46 | if self.scheduler: 47 | self.scheduler.step() 48 | acc = 100. * correct / total 49 | total_loss /= total 50 | if log_pref: 51 | print("{}: Accuracy {:.3f}, Loss {:.3f}".format(log_pref, acc, total_loss)) 52 | return acc, total_loss 53 | 54 | def train_defend_ppb(self, train_loader, log_pref="",defend_arg=None): 55 | self.model.train() 56 | total_loss = 0 57 | total_loss1 = 0 58 | total_loss2 = 0 59 | correct = 0 60 | total = 0 61 | for inputs, targets in train_loader: 62 | self.optimizer.zero_grad() 63 | inputs, targets = inputs.to(self.device), targets.to(self.device) 64 | outputs = self.model(inputs) 65 | loss1 = self.criterion(outputs, targets) 66 | ranked_outputs, _ = torch.topk(outputs, self.num_cls, dim=-1) 67 | size = targets.size(0) 68 | even_size = size // 2 * 2 69 | if even_size > 0: 70 | loss2 = F.kl_div(F.log_softmax(ranked_outputs[:even_size // 2], dim=-1), 71 | F.softmax(ranked_outputs[even_size // 2:even_size], dim=-1), 72 | reduction='batchmean') 73 | else: 74 | loss2 = torch.zeros(1).to(self.device) 75 | loss = loss1 + defend_arg * loss2 76 | total_loss += loss.item() * size 77 | total_loss1 += loss1.item() * size 78 | total_loss2 += loss2.item() * size 79 | total += size 80 | _, predicted = outputs.max(1) 81 | correct += predicted.eq(targets).sum().item() 82 | loss.backward() 83 | self.optimizer.step() 84 | acc = 100. 
* correct / total 85 | total_loss /= total 86 | total_loss1 /= total 87 | total_loss2 /= total 88 | 89 | if self.scheduler: 90 | self.scheduler.step() 91 | if log_pref: 92 | print("{}: Accuracy {:.3f}, Loss {:.3f}, Loss1 {:.3f}, Loss2 {:.3f}".format( 93 | log_pref, acc, total_loss, total_loss1, total_loss2)) 94 | return acc, total_loss 95 | 96 | def train_defend_adv(self, train_loader, test_loader, log_pref="", privacy_theta=1.0): 97 | """ 98 | modified from 99 | https://github.com/Lab41/cyphercat/blob/master/Defenses/Adversarial_Regularization.ipynb 100 | """ 101 | total_loss = 0 102 | correct = 0 103 | total = 0 104 | infer_iterations = 7 105 | # train adversarial network 106 | 107 | train_iter = iter(train_loader) 108 | test_iter = iter(test_loader) 109 | train_iter2 = iter(train_loader) 110 | 111 | self.model.eval() 112 | self.attack_model.train() 113 | for infer_iter in range(infer_iterations): 114 | with torch.no_grad(): 115 | try: 116 | inputs, targets = next(train_iter) 117 | except StopIteration: 118 | train_iter = iter(train_loader) 119 | inputs, targets = next(train_iter) 120 | inputs, targets = inputs.to(self.device), targets.to(self.device) 121 | in_predicts = F.softmax(self.model(inputs), dim=-1) 122 | in_targets = F.one_hot(targets, num_classes=self.num_cls).float() 123 | 124 | try: 125 | inputs, targets = next(test_iter) 126 | except StopIteration: 127 | test_iter = iter(test_loader) 128 | inputs, targets = next(test_iter) 129 | inputs, targets = inputs.to(self.device), targets.to(self.device) 130 | out_predicts = F.softmax(self.model(inputs), dim=-1) 131 | out_targets = F.one_hot(targets, num_classes=self.num_cls).float() 132 | 133 | infer_train_data = torch.cat([torch.cat([in_predicts, in_targets], dim=-1), 134 | torch.cat([out_predicts, out_targets], dim=-1)], dim=0) 135 | infer_train_label = torch.cat([torch.ones(in_predicts.size(0)), 136 | torch.zeros(out_predicts.size(0))]).long().to(self.device) 137 | 138 | self.attack_model_optim.zero_grad() 
139 | infer_loss = privacy_theta * F.cross_entropy(self.attack_model(infer_train_data), infer_train_label) 140 | infer_loss.backward() 141 | self.attack_model_optim.step() 142 | 143 | self.model.train() 144 | self.attack_model.eval() 145 | try: 146 | inputs, targets = next(train_iter2) 147 | except StopIteration: 148 | train_iter2 = iter(train_loader) 149 | inputs, targets = next(train_iter2) 150 | inputs, targets = inputs.to(self.device), targets.to(self.device) 151 | self.optimizer.zero_grad() 152 | outputs = self.model(inputs) 153 | loss1 = self.criterion(outputs, targets) 154 | in_predicts = F.softmax(outputs, dim=-1) 155 | in_targets = F.one_hot(targets, num_classes=self.num_cls).float() 156 | infer_data = torch.cat([in_predicts, in_targets], dim=-1) 157 | infer_labels = torch.ones(targets.size(0)).long().to(self.device) 158 | infer_loss = F.cross_entropy(self.attack_model(infer_data), infer_labels) 159 | loss = loss1 - privacy_theta * infer_loss 160 | loss.backward() 161 | self.optimizer.step() 162 | total_loss += loss.item() * targets.size(0) 163 | total += targets.size(0) 164 | _, predicted = outputs.max(1) 165 | correct += predicted.eq(targets).sum().item() 166 | if self.scheduler: 167 | self.scheduler.step() 168 | acc = 100. 
* correct / total 169 | total_loss /= total 170 | if log_pref: 171 | print("{}: Accuracy {:.3f}, Loss {:.3f}".format(log_pref, acc, total_loss)) 172 | return acc, total_loss 173 | 174 | def test(self, test_loader, log_pref=""): 175 | self.model.eval() 176 | total_loss = 0 177 | correct = 0 178 | total = 0 179 | with torch.no_grad(): 180 | for batch_idx, (inputs, targets) in enumerate(test_loader): 181 | inputs, targets = inputs.to(self.device), targets.to(self.device) 182 | outputs = self.model(inputs) 183 | loss = self.criterion(outputs, targets) 184 | total_loss += loss.item() * targets.size(0) 185 | if isinstance(self.criterion, nn.BCELoss): 186 | correct += torch.sum(torch.round(outputs) == targets) 187 | else: 188 | _, predicted = outputs.max(1) 189 | correct += predicted.eq(targets).sum().item() 190 | total += targets.size(0) 191 | 192 | acc = 100. * correct / total 193 | total_loss /= total 194 | if log_pref: 195 | print("{}: Accuracy {:.3f}, Loss {:.3f}".format(log_pref, acc, total_loss)) 196 | return acc, total_loss 197 | 198 | def save(self, epoch, acc, loss): 199 | save_path = f"{self.save_pref}/{epoch}.pth" 200 | state = { 201 | 'epoch': epoch + 1, 202 | 'acc': acc, 203 | 'loss': loss, 204 | 'state': self.model.state_dict() 205 | } 206 | torch.save(state, save_path) 207 | return save_path 208 | 209 | def load(self, load_path, verbose=False): 210 | state = torch.load(load_path, map_location=self.device) 211 | acc = state['acc'] 212 | if verbose: 213 | print(f"Load model from {load_path}") 214 | print(f"Epoch {state['epoch']}, Acc: {state['acc']:.3f}, Loss: {state['loss']:.3f}") 215 | self.model.load_state_dict(state['state']) 216 | return acc 217 | 218 | def predict_target_sensitivity(self, data_loader, m=10, epsilon=1e-3): 219 | self.model.eval() 220 | predict_list = [] 221 | sensitivity_list = [] 222 | target_list = [] 223 | with torch.no_grad(): 224 | for inputs, targets in data_loader: 225 | inputs = inputs.to(self.device) 226 | outputs = 
self.model(inputs) 227 | predicts = F.softmax(outputs, dim=-1) 228 | predict_list.append(predicts.detach().data.cpu()) 229 | target_list.append(targets) 230 | 231 | if len(inputs.size()) == 4: 232 | x = inputs.repeat((m, 1, 1, 1)) 233 | elif len(inputs.size()) == 3: 234 | x = inputs.repeat((m, 1, 1)) 235 | elif len(inputs.size()) == 2: 236 | x = inputs.repeat((m, 1)) 237 | u = torch.randn_like(x) 238 | evaluation_points = x + epsilon * u 239 | new_predicts = F.softmax(self.model(evaluation_points), dim=-1) 240 | diff = torch.abs(new_predicts - predicts.repeat((m, 1))) 241 | diff = diff.view(m, -1, self.num_cls) 242 | sensitivity = diff.mean(dim=0) / epsilon 243 | sensitivity_list.append(sensitivity.detach().data.cpu()) 244 | 245 | targets = torch.cat(target_list, dim=0) 246 | predicts = torch.cat(predict_list, dim=0) 247 | sensitivities = torch.cat(sensitivity_list, dim=0) 248 | return predicts, targets, sensitivities 249 | -------------------------------------------------------------------------------- /cifar10_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mobilenetv2 import * 2 | from .resnet import * 3 | from .vgg import * 4 | from .densenet import * 5 | from .resnet_orig import * 6 | from .googlenet import * 7 | from .inception import * 8 | 9 | # This folder is modified from "https://github.com/huyvnphan/PyTorch_CIFAR10" -------------------------------------------------------------------------------- /cifar10_models/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from collections import OrderedDict 6 | import os 7 | 8 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 9 | 10 | class _DenseLayer(nn.Sequential): 11 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 12 | super(_DenseLayer, 
self).__init__() 13 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 14 | self.add_module('relu1', nn.ReLU(inplace=True)), 15 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 16 | growth_rate, kernel_size=1, stride=1, 17 | bias=False)), 18 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 19 | self.add_module('relu2', nn.ReLU(inplace=True)), 20 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 21 | kernel_size=3, stride=1, padding=1, 22 | bias=False)), 23 | self.drop_rate = drop_rate 24 | 25 | def forward(self, x): 26 | new_features = super(_DenseLayer, self).forward(x) 27 | if self.drop_rate > 0: 28 | new_features = F.dropout(new_features, p=self.drop_rate, 29 | training=self.training) 30 | return torch.cat([x, new_features], 1) 31 | 32 | 33 | class _DenseBlock(nn.Sequential): 34 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 35 | super(_DenseBlock, self).__init__() 36 | for i in range(num_layers): 37 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, 38 | bn_size, drop_rate) 39 | self.add_module('denselayer%d' % (i + 1), layer) 40 | 41 | 42 | class _Transition(nn.Sequential): 43 | def __init__(self, num_input_features, num_output_features): 44 | super(_Transition, self).__init__() 45 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 46 | self.add_module('relu', nn.ReLU(inplace=True)) 47 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 48 | kernel_size=1, stride=1, bias=False)) 49 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 50 | 51 | 52 | class DenseNet(nn.Module): 53 | r"""Densenet-BC model class, based on 54 | `"Densely Connected Convolutional Networks" `_ 55 | 56 | Args: 57 | growth_rate (int) - how many filters to add each layer (`k` in paper) 58 | block_config (list of 4 ints) - how many layers in each pooling block 59 | num_init_features (int) - the number of 
filters to learn in the first convolution layer 60 | bn_size (int) - multiplicative factor for number of bottle neck layers 61 | (i.e. bn_size * k features in the bottleneck layer) 62 | drop_rate (float) - dropout rate after each dense layer 63 | num_classes (int) - number of classification classes 64 | """ 65 | 66 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 67 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=10): 68 | 69 | super(DenseNet, self).__init__() 70 | 71 | # First convolution 72 | 73 | # CIFAR-10: kernel_size 7 ->3, stride 2->1, padding 3->1 74 | self.features = nn.Sequential(OrderedDict([ 75 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, 76 | padding=1, bias=False)), 77 | ('norm0', nn.BatchNorm2d(num_init_features)), 78 | ('relu0', nn.ReLU(inplace=True)), 79 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 80 | ])) 81 | ## END 82 | 83 | # Each denseblock 84 | num_features = num_init_features 85 | for i, num_layers in enumerate(block_config): 86 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 87 | bn_size=bn_size, growth_rate=growth_rate, 88 | drop_rate=drop_rate) 89 | self.features.add_module('denseblock%d' % (i + 1), block) 90 | num_features = num_features + num_layers * growth_rate 91 | if i != len(block_config) - 1: 92 | trans = _Transition(num_input_features=num_features, 93 | num_output_features=num_features // 2) 94 | self.features.add_module('transition%d' % (i + 1), trans) 95 | num_features = num_features // 2 96 | 97 | # Final batch norm 98 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 99 | 100 | # Linear layer 101 | self.classifier = nn.Linear(num_features, num_classes) 102 | 103 | # Official init from torch repo. 
104 | for m in self.modules(): 105 | if isinstance(m, nn.Conv2d): 106 | nn.init.kaiming_normal_(m.weight) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | nn.init.constant_(m.weight, 1) 109 | nn.init.constant_(m.bias, 0) 110 | elif isinstance(m, nn.Linear): 111 | nn.init.constant_(m.bias, 0) 112 | 113 | def forward(self, x): 114 | features = self.features(x) 115 | out = F.relu(features, inplace=True) 116 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 117 | out = self.classifier(out) 118 | return out 119 | 120 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, device, **kwargs): 121 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 122 | if pretrained: 123 | script_dir = os.path.dirname(__file__) 124 | state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 125 | model.load_state_dict(state_dict) 126 | return model 127 | 128 | 129 | def densenet121(pretrained=False, progress=True, device='cpu', **kwargs): 130 | r"""Densenet-121 model from 131 | `"Densely Connected Convolutional Networks" `_ 132 | 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | progress (bool): If True, displays a progress bar of the download to stderr 136 | """ 137 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, device, 138 | **kwargs) 139 | 140 | 141 | def densenet161(pretrained=False, progress=True, device='cpu', **kwargs): 142 | r"""Densenet-161 model from 143 | `"Densely Connected Convolutional Networks" `_ 144 | 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | progress (bool): If True, displays a progress bar of the download to stderr 148 | """ 149 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, device, 150 | **kwargs) 151 | 152 | 153 | def densenet169(pretrained=False, progress=True, device='cpu', **kwargs): 154 | 
r"""Densenet-169 model from 155 | `"Densely Connected Convolutional Networks" `_ 156 | 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | progress (bool): If True, displays a progress bar of the download to stderr 160 | """ 161 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, device, 162 | **kwargs) 163 | 164 | 165 | def densenet201(pretrained=False, progress=True, device='cpu', **kwargs): 166 | r"""Densenet-201 model from 167 | `"Densely Connected Convolutional Networks" `_ 168 | 169 | Args: 170 | pretrained (bool): If True, returns a model pre-trained on ImageNet 171 | progress (bool): If True, displays a progress bar of the download to stderr 172 | """ 173 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, device, 174 | **kwargs) 175 | -------------------------------------------------------------------------------- /cifar10_models/googlenet.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import os 7 | 8 | __all__ = ['GoogLeNet', 'googlenet'] 9 | 10 | 11 | _GoogLeNetOuputs = namedtuple('GoogLeNetOuputs', ['logits', 'aux_logits2', 'aux_logits1']) 12 | 13 | 14 | def googlenet(pretrained=False, progress=True, device='cpu', **kwargs): 15 | r"""GoogLeNet (Inception v1) model architecture from 16 | `"Going Deeper with Convolutions" `_. 17 | 18 | Args: 19 | pretrained (bool): If True, returns a model pre-trained on ImageNet 20 | progress (bool): If True, displays a progress bar of the download to stderr 21 | aux_logits (bool): If True, adds two auxiliary branches that can improve training. 22 | Default: *False* when pretrained is True otherwise *True* 23 | transform_input (bool): If True, preprocesses the input according to the method with which it 24 | was trained on ImageNet. 
Default: *False* 25 | """ 26 | model = GoogLeNet() 27 | if pretrained: 28 | script_dir = os.path.dirname(__file__) 29 | state_dict = torch.load(script_dir + '/state_dicts/googlenet.pt', map_location=device) 30 | model.load_state_dict(state_dict) 31 | return model 32 | 33 | 34 | class GoogLeNet(nn.Module): 35 | 36 | ## CIFAR10: aux_logits True->False 37 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 38 | super(GoogLeNet, self).__init__() 39 | self.aux_logits = aux_logits 40 | self.transform_input = transform_input 41 | 42 | ## CIFAR10: out_channels 64->192, kernel_size 7->3, stride 2->1, padding 3->1 43 | self.conv1 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 44 | # self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 45 | # self.conv2 = BasicConv2d(64, 64, kernel_size=1) 46 | # self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1) 47 | # self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True) 48 | ## END 49 | 50 | self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32) 51 | self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64) 52 | 53 | ## CIFAR10: padding 0->1, ciel_model True->False 54 | self.maxpool3 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 55 | ## END 56 | 57 | self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64) 58 | self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64) 59 | self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64) 60 | self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64) 61 | self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128) 62 | 63 | ## CIFAR10: kernel_size 2->3, padding 0->1, ciel_model True->False 64 | self.maxpool4 = nn.MaxPool2d(3, stride=2, padding=1, ceil_mode=False) 65 | ## END 66 | 67 | self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128) 68 | self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128) 69 | 70 | if aux_logits: 71 | self.aux1 = InceptionAux(512, num_classes) 72 | self.aux2 = 
InceptionAux(528, num_classes) 73 | 74 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 75 | self.dropout = nn.Dropout(0.2) 76 | self.fc = nn.Linear(1024, num_classes) 77 | 78 | # if init_weights: 79 | # self._initialize_weights() 80 | 81 | # def _initialize_weights(self): 82 | # for m in self.modules(): 83 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 84 | # import scipy.stats as stats 85 | # X = stats.truncnorm(-2, 2, scale=0.01) 86 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 87 | # values = values.view(m.weight.size()) 88 | # with torch.no_grad(): 89 | # m.weight.copy_(values) 90 | # elif isinstance(m, nn.BatchNorm2d): 91 | # nn.init.constant_(m.weight, 1) 92 | # nn.init.constant_(m.bias, 0) 93 | 94 | def forward(self, x): 95 | if self.transform_input: 96 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 97 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 98 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 99 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 100 | 101 | # N x 3 x 224 x 224 102 | x = self.conv1(x) 103 | 104 | ## CIFAR10 105 | # N x 64 x 112 x 112 106 | # x = self.maxpool1(x) 107 | # N x 64 x 56 x 56 108 | # x = self.conv2(x) 109 | # N x 64 x 56 x 56 110 | # x = self.conv3(x) 111 | # N x 192 x 56 x 56 112 | # x = self.maxpool2(x) 113 | ## END 114 | 115 | # N x 192 x 28 x 28 116 | x = self.inception3a(x) 117 | # N x 256 x 28 x 28 118 | x = self.inception3b(x) 119 | # N x 480 x 28 x 28 120 | x = self.maxpool3(x) 121 | # N x 480 x 14 x 14 122 | x = self.inception4a(x) 123 | # N x 512 x 14 x 14 124 | if self.training and self.aux_logits: 125 | aux1 = self.aux1(x) 126 | 127 | x = self.inception4b(x) 128 | # N x 512 x 14 x 14 129 | x = self.inception4c(x) 130 | # N x 512 x 14 x 14 131 | x = self.inception4d(x) 132 | # N x 528 x 14 x 14 133 | if self.training and self.aux_logits: 134 | aux2 = self.aux2(x) 135 | 136 | x = self.inception4e(x) 
137 | # N x 832 x 14 x 14 138 | x = self.maxpool4(x) 139 | # N x 832 x 7 x 7 140 | x = self.inception5a(x) 141 | # N x 832 x 7 x 7 142 | x = self.inception5b(x) 143 | # N x 1024 x 7 x 7 144 | 145 | x = self.avgpool(x) 146 | # N x 1024 x 1 x 1 147 | x = x.view(x.size(0), -1) 148 | # N x 1024 149 | x = self.dropout(x) 150 | x = self.fc(x) 151 | # N x 1000 (num_classes) 152 | if self.training and self.aux_logits: 153 | return _GoogLeNetOuputs(x, aux2, aux1) 154 | return x 155 | 156 | 157 | class Inception(nn.Module): 158 | 159 | def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj): 160 | super(Inception, self).__init__() 161 | 162 | self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1) 163 | 164 | self.branch2 = nn.Sequential( 165 | BasicConv2d(in_channels, ch3x3red, kernel_size=1), 166 | BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1) 167 | ) 168 | 169 | self.branch3 = nn.Sequential( 170 | BasicConv2d(in_channels, ch5x5red, kernel_size=1), 171 | BasicConv2d(ch5x5red, ch5x5, kernel_size=3, padding=1) 172 | ) 173 | 174 | self.branch4 = nn.Sequential( 175 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True), 176 | BasicConv2d(in_channels, pool_proj, kernel_size=1) 177 | ) 178 | 179 | def forward(self, x): 180 | branch1 = self.branch1(x) 181 | branch2 = self.branch2(x) 182 | branch3 = self.branch3(x) 183 | branch4 = self.branch4(x) 184 | 185 | outputs = [branch1, branch2, branch3, branch4] 186 | return torch.cat(outputs, 1) 187 | 188 | 189 | class InceptionAux(nn.Module): 190 | 191 | def __init__(self, in_channels, num_classes): 192 | super(InceptionAux, self).__init__() 193 | self.conv = BasicConv2d(in_channels, 128, kernel_size=1) 194 | 195 | self.fc1 = nn.Linear(2048, 1024) 196 | self.fc2 = nn.Linear(1024, num_classes) 197 | 198 | def forward(self, x): 199 | # aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14 200 | x = F.adaptive_avg_pool2d(x, (4, 4)) 201 | # aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4 202 | 
x = self.conv(x) 203 | # N x 128 x 4 x 4 204 | x = x.view(x.size(0), -1) 205 | # N x 2048 206 | x = F.relu(self.fc1(x), inplace=True) 207 | # N x 2048 208 | x = F.dropout(x, 0.7, training=self.training) 209 | # N x 2048 210 | x = self.fc2(x) 211 | # N x 1024 212 | 213 | return x 214 | 215 | 216 | class BasicConv2d(nn.Module): 217 | 218 | def __init__(self, in_channels, out_channels, **kwargs): 219 | super(BasicConv2d, self).__init__() 220 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 221 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 222 | 223 | def forward(self, x): 224 | x = self.conv(x) 225 | x = self.bn(x) 226 | return F.relu(x, inplace=True) 227 | -------------------------------------------------------------------------------- /cifar10_models/inception.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | 7 | __all__ = ['Inception3', 'inception_v3'] 8 | 9 | 10 | _InceptionOuputs = namedtuple('InceptionOuputs', ['logits', 'aux_logits']) 11 | 12 | 13 | def inception_v3(pretrained=False, progress=True, device='cpu', **kwargs): 14 | r"""Inception v3 model architecture from 15 | `"Rethinking the Inception Architecture for Computer Vision" `_. 16 | 17 | .. note:: 18 | **Important**: In contrast to the other models the inception_v3 expects tensors with a size of 19 | N x 3 x 299 x 299, so ensure your images are sized accordingly. 20 | 21 | Args: 22 | pretrained (bool): If True, returns a model pre-trained on ImageNet 23 | progress (bool): If True, displays a progress bar of the download to stderr 24 | aux_logits (bool): If True, add an auxiliary branch that can improve training. 25 | Default: *True* 26 | transform_input (bool): If True, preprocesses the input according to the method with which it 27 | was trained on ImageNet. 
Default: *False* 28 | """ 29 | model = Inception3() 30 | if pretrained: 31 | script_dir = os.path.dirname(__file__) 32 | state_dict = torch.load(script_dir + '/state_dicts/inception_v3.pt', map_location=device) 33 | model.load_state_dict(state_dict) 34 | return model 35 | 36 | class Inception3(nn.Module): 37 | ## CIFAR10: aux_logits True->False 38 | def __init__(self, num_classes=10, aux_logits=False, transform_input=False): 39 | super(Inception3, self).__init__() 40 | self.aux_logits = aux_logits 41 | self.transform_input = transform_input 42 | 43 | ## CIFAR10: stride 2->1, padding 0 -> 1 44 | self.Conv2d_1a_3x3 = BasicConv2d(3, 192, kernel_size=3, stride=1, padding=1) 45 | # self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3) 46 | # self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1) 47 | # self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1) 48 | # self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3) 49 | self.Mixed_5b = InceptionA(192, pool_features=32) 50 | self.Mixed_5c = InceptionA(256, pool_features=64) 51 | self.Mixed_5d = InceptionA(288, pool_features=64) 52 | self.Mixed_6a = InceptionB(288) 53 | self.Mixed_6b = InceptionC(768, channels_7x7=128) 54 | self.Mixed_6c = InceptionC(768, channels_7x7=160) 55 | self.Mixed_6d = InceptionC(768, channels_7x7=160) 56 | self.Mixed_6e = InceptionC(768, channels_7x7=192) 57 | if aux_logits: 58 | self.AuxLogits = InceptionAux(768, num_classes) 59 | self.Mixed_7a = InceptionD(768) 60 | self.Mixed_7b = InceptionE(1280) 61 | self.Mixed_7c = InceptionE(2048) 62 | self.fc = nn.Linear(2048, num_classes) 63 | 64 | # for m in self.modules(): 65 | # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 66 | # import scipy.stats as stats 67 | # stddev = m.stddev if hasattr(m, 'stddev') else 0.1 68 | # X = stats.truncnorm(-2, 2, scale=stddev) 69 | # values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype) 70 | # values = values.view(m.weight.size()) 71 | # with torch.no_grad(): 72 | # 
m.weight.copy_(values) 73 | # elif isinstance(m, nn.BatchNorm2d): 74 | # nn.init.constant_(m.weight, 1) 75 | # nn.init.constant_(m.bias, 0) 76 | 77 | def forward(self, x): 78 | if self.transform_input: 79 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 80 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 81 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 82 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 83 | # N x 3 x 299 x 299 84 | x = self.Conv2d_1a_3x3(x) 85 | 86 | ## CIFAR10 87 | # N x 32 x 149 x 149 88 | # x = self.Conv2d_2a_3x3(x) 89 | # N x 32 x 147 x 147 90 | # x = self.Conv2d_2b_3x3(x) 91 | # N x 64 x 147 x 147 92 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 93 | # N x 64 x 73 x 73 94 | # x = self.Conv2d_3b_1x1(x) 95 | # N x 80 x 73 x 73 96 | # x = self.Conv2d_4a_3x3(x) 97 | # N x 192 x 71 x 71 98 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 99 | # N x 192 x 35 x 35 100 | x = self.Mixed_5b(x) 101 | # N x 256 x 35 x 35 102 | x = self.Mixed_5c(x) 103 | # N x 288 x 35 x 35 104 | x = self.Mixed_5d(x) 105 | # N x 288 x 35 x 35 106 | x = self.Mixed_6a(x) 107 | # N x 768 x 17 x 17 108 | x = self.Mixed_6b(x) 109 | # N x 768 x 17 x 17 110 | x = self.Mixed_6c(x) 111 | # N x 768 x 17 x 17 112 | x = self.Mixed_6d(x) 113 | # N x 768 x 17 x 17 114 | x = self.Mixed_6e(x) 115 | # N x 768 x 17 x 17 116 | if self.training and self.aux_logits: 117 | aux = self.AuxLogits(x) 118 | # N x 768 x 17 x 17 119 | x = self.Mixed_7a(x) 120 | # N x 1280 x 8 x 8 121 | x = self.Mixed_7b(x) 122 | # N x 2048 x 8 x 8 123 | x = self.Mixed_7c(x) 124 | # N x 2048 x 8 x 8 125 | # Adaptive average pooling 126 | x = F.adaptive_avg_pool2d(x, (1, 1)) 127 | # N x 2048 x 1 x 1 128 | x = F.dropout(x, training=self.training) 129 | # N x 2048 x 1 x 1 130 | x = x.view(x.size(0), -1) 131 | # N x 2048 132 | x = self.fc(x) 133 | # N x 1000 (num_classes) 134 | if self.training and self.aux_logits: 135 | return 
_InceptionOuputs(x, aux) 136 | return x 137 | 138 | def forward_last(self, x): 139 | if self.transform_input: 140 | x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 141 | x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 142 | x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 143 | x = torch.cat((x_ch0, x_ch1, x_ch2), 1) 144 | # N x 3 x 299 x 299 145 | x = self.Conv2d_1a_3x3(x) 146 | 147 | ## CIFAR10 148 | # N x 32 x 149 x 149 149 | # x = self.Conv2d_2a_3x3(x) 150 | # N x 32 x 147 x 147 151 | # x = self.Conv2d_2b_3x3(x) 152 | # N x 64 x 147 x 147 153 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 154 | # N x 64 x 73 x 73 155 | # x = self.Conv2d_3b_1x1(x) 156 | # N x 80 x 73 x 73 157 | # x = self.Conv2d_4a_3x3(x) 158 | # N x 192 x 71 x 71 159 | # x = F.max_pool2d(x, kernel_size=3, stride=2) 160 | # N x 192 x 35 x 35 161 | x = self.Mixed_5b(x) 162 | # N x 256 x 35 x 35 163 | x = self.Mixed_5c(x) 164 | # N x 288 x 35 x 35 165 | x = self.Mixed_5d(x) 166 | # N x 288 x 35 x 35 167 | x = self.Mixed_6a(x) 168 | # N x 768 x 17 x 17 169 | x = self.Mixed_6b(x) 170 | # N x 768 x 17 x 17 171 | x = self.Mixed_6c(x) 172 | # N x 768 x 17 x 17 173 | x = self.Mixed_6d(x) 174 | # N x 768 x 17 x 17 175 | x = self.Mixed_6e(x) 176 | # N x 768 x 17 x 17 177 | if self.training and self.aux_logits: 178 | aux = self.AuxLogits(x) 179 | # N x 768 x 17 x 17 180 | x = self.Mixed_7a(x) 181 | # N x 1280 x 8 x 8 182 | x = self.Mixed_7b(x) 183 | # N x 2048 x 8 x 8 184 | x = self.Mixed_7c(x) 185 | # N x 2048 x 8 x 8 186 | # Adaptive average pooling 187 | x = F.adaptive_avg_pool2d(x, (1, 1)) 188 | # N x 2048 x 1 x 1 189 | x = F.dropout(x, training=self.training) 190 | # N x 2048 x 1 x 1 191 | last = x.view(x.size(0), -1) 192 | # N x 2048 193 | output = self.fc(last) 194 | # N x 1000 (num_classes) 195 | if self.training and self.aux_logits: 196 | return _InceptionOuputs(x, aux) 197 | return output, last 198 | 199 | class 
InceptionA(nn.Module): 200 | 201 | def __init__(self, in_channels, pool_features): 202 | super(InceptionA, self).__init__() 203 | self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1) 204 | 205 | self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1) 206 | self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2) 207 | 208 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 209 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 210 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1) 211 | 212 | self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1) 213 | 214 | def forward(self, x): 215 | branch1x1 = self.branch1x1(x) 216 | 217 | branch5x5 = self.branch5x5_1(x) 218 | branch5x5 = self.branch5x5_2(branch5x5) 219 | 220 | branch3x3dbl = self.branch3x3dbl_1(x) 221 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 222 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 223 | 224 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 225 | branch_pool = self.branch_pool(branch_pool) 226 | 227 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 228 | return torch.cat(outputs, 1) 229 | 230 | 231 | class InceptionB(nn.Module): 232 | 233 | def __init__(self, in_channels): 234 | super(InceptionB, self).__init__() 235 | self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2) 236 | 237 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1) 238 | self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1) 239 | self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2) 240 | 241 | def forward(self, x): 242 | branch3x3 = self.branch3x3(x) 243 | 244 | branch3x3dbl = self.branch3x3dbl_1(x) 245 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 246 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 247 | 248 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 249 | 250 | outputs = [branch3x3, branch3x3dbl, branch_pool] 
251 | return torch.cat(outputs, 1) 252 | 253 | 254 | class InceptionC(nn.Module): 255 | 256 | def __init__(self, in_channels, channels_7x7): 257 | super(InceptionC, self).__init__() 258 | self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1) 259 | 260 | c7 = channels_7x7 261 | self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1) 262 | self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 263 | self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0)) 264 | 265 | self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1) 266 | self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 267 | self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3)) 268 | self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0)) 269 | self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3)) 270 | 271 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 272 | 273 | def forward(self, x): 274 | branch1x1 = self.branch1x1(x) 275 | 276 | branch7x7 = self.branch7x7_1(x) 277 | branch7x7 = self.branch7x7_2(branch7x7) 278 | branch7x7 = self.branch7x7_3(branch7x7) 279 | 280 | branch7x7dbl = self.branch7x7dbl_1(x) 281 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 282 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 283 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 284 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 285 | 286 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 287 | branch_pool = self.branch_pool(branch_pool) 288 | 289 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 290 | return torch.cat(outputs, 1) 291 | 292 | 293 | class InceptionD(nn.Module): 294 | 295 | def __init__(self, in_channels): 296 | super(InceptionD, self).__init__() 297 | self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 298 | self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2) 299 | 300 
| self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1) 301 | self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3)) 302 | self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0)) 303 | self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2) 304 | 305 | def forward(self, x): 306 | branch3x3 = self.branch3x3_1(x) 307 | branch3x3 = self.branch3x3_2(branch3x3) 308 | 309 | branch7x7x3 = self.branch7x7x3_1(x) 310 | branch7x7x3 = self.branch7x7x3_2(branch7x7x3) 311 | branch7x7x3 = self.branch7x7x3_3(branch7x7x3) 312 | branch7x7x3 = self.branch7x7x3_4(branch7x7x3) 313 | 314 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=2) 315 | outputs = [branch3x3, branch7x7x3, branch_pool] 316 | return torch.cat(outputs, 1) 317 | 318 | 319 | class InceptionE(nn.Module): 320 | 321 | def __init__(self, in_channels): 322 | super(InceptionE, self).__init__() 323 | self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1) 324 | 325 | self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1) 326 | self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 327 | self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 328 | 329 | self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1) 330 | self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1) 331 | self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1)) 332 | self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0)) 333 | 334 | self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1) 335 | 336 | def forward(self, x): 337 | branch1x1 = self.branch1x1(x) 338 | 339 | branch3x3 = self.branch3x3_1(x) 340 | branch3x3 = [ 341 | self.branch3x3_2a(branch3x3), 342 | self.branch3x3_2b(branch3x3), 343 | ] 344 | branch3x3 = torch.cat(branch3x3, 1) 345 | 346 | branch3x3dbl = self.branch3x3dbl_1(x) 347 | branch3x3dbl = 
self.branch3x3dbl_2(branch3x3dbl) 348 | branch3x3dbl = [ 349 | self.branch3x3dbl_3a(branch3x3dbl), 350 | self.branch3x3dbl_3b(branch3x3dbl), 351 | ] 352 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 353 | 354 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) 355 | branch_pool = self.branch_pool(branch_pool) 356 | 357 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 358 | return torch.cat(outputs, 1) 359 | 360 | 361 | class InceptionAux(nn.Module): 362 | 363 | def __init__(self, in_channels, num_classes): 364 | super(InceptionAux, self).__init__() 365 | self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1) 366 | self.conv1 = BasicConv2d(128, 768, kernel_size=5) 367 | self.conv1.stddev = 0.01 368 | self.fc = nn.Linear(768, num_classes) 369 | self.fc.stddev = 0.001 370 | 371 | def forward(self, x): 372 | # N x 768 x 17 x 17 373 | x = F.avg_pool2d(x, kernel_size=5, stride=3) 374 | # N x 768 x 5 x 5 375 | x = self.conv0(x) 376 | # N x 128 x 5 x 5 377 | x = self.conv1(x) 378 | # N x 768 x 1 x 1 379 | # Adaptive average pooling 380 | x = F.adaptive_avg_pool2d(x, (1, 1)) 381 | # N x 768 x 1 x 1 382 | x = x.view(x.size(0), -1) 383 | # N x 768 384 | x = self.fc(x) 385 | # N x 1000 386 | return x 387 | 388 | 389 | class BasicConv2d(nn.Module): 390 | 391 | def __init__(self, in_channels, out_channels, **kwargs): 392 | super(BasicConv2d, self).__init__() 393 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 394 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 395 | 396 | def forward(self, x): 397 | x = self.conv(x) 398 | x = self.bn(x) 399 | return F.relu(x, inplace=True) 400 | -------------------------------------------------------------------------------- /cifar10_models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 6 | 7 | 8 | class ConvBNReLU(nn.Sequential): 9 
| def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 10 | padding = (kernel_size - 1) // 2 11 | super(ConvBNReLU, self).__init__( 12 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 13 | nn.BatchNorm2d(out_planes), 14 | nn.ReLU6(inplace=True) 15 | ) 16 | 17 | 18 | class InvertedResidual(nn.Module): 19 | def __init__(self, inp, oup, stride, expand_ratio): 20 | super(InvertedResidual, self).__init__() 21 | self.stride = stride 22 | assert stride in [1, 2] 23 | 24 | hidden_dim = int(round(inp * expand_ratio)) 25 | self.use_res_connect = self.stride == 1 and inp == oup 26 | 27 | layers = [] 28 | if expand_ratio != 1: 29 | # pw 30 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 31 | layers.extend([ 32 | # dw 33 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 34 | # pw-linear 35 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 36 | nn.BatchNorm2d(oup), 37 | ]) 38 | self.conv = nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | if self.use_res_connect: 42 | return x + self.conv(x) 43 | else: 44 | return self.conv(x) 45 | 46 | 47 | class MobileNetV2(nn.Module): 48 | def __init__(self, num_classes=10, width_mult=1.0): 49 | super(MobileNetV2, self).__init__() 50 | block = InvertedResidual 51 | input_channel = 32 52 | last_channel = 1280 53 | 54 | ## CIFAR10 55 | inverted_residual_setting = [ 56 | # t, c, n, s 57 | [1, 16, 1, 1], 58 | [6, 24, 2, 1], # Stride 2 -> 1 for CIFAR-10 59 | [6, 32, 3, 2], 60 | [6, 64, 4, 2], 61 | [6, 96, 3, 1], 62 | [6, 160, 3, 2], 63 | [6, 320, 1, 1], 64 | ] 65 | ## END 66 | 67 | # building first layer 68 | input_channel = int(input_channel * width_mult) 69 | self.last_channel = int(last_channel * max(1.0, width_mult)) 70 | 71 | # CIFAR10: stride 2 -> 1 72 | features = [ConvBNReLU(3, input_channel, stride=1)] 73 | # END 74 | 75 | # building inverted residual blocks 76 | for t, c, n, s in inverted_residual_setting: 77 | output_channel 
= int(c * width_mult) 78 | for i in range(n): 79 | stride = s if i == 0 else 1 80 | features.append(block(input_channel, output_channel, stride, expand_ratio=t)) 81 | input_channel = output_channel 82 | # building last several layers 83 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 84 | # make it nn.Sequential 85 | self.features = nn.Sequential(*features) 86 | 87 | # building classifier 88 | self.classifier = nn.Sequential( 89 | nn.Dropout(0.2), 90 | nn.Linear(self.last_channel, num_classes), 91 | ) 92 | 93 | # weight initialization 94 | for m in self.modules(): 95 | if isinstance(m, nn.Conv2d): 96 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 97 | if m.bias is not None: 98 | nn.init.zeros_(m.bias) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | nn.init.ones_(m.weight) 101 | nn.init.zeros_(m.bias) 102 | elif isinstance(m, nn.Linear): 103 | nn.init.normal_(m.weight, 0, 0.01) 104 | nn.init.zeros_(m.bias) 105 | 106 | def forward(self, x): 107 | x = self.features(x) 108 | x = x.mean([2, 3]) 109 | x = self.classifier(x) 110 | return x 111 | 112 | 113 | def mobilenet_v2(pretrained=False, progress=True, device='cpu', **kwargs): 114 | """ 115 | Constructs a MobileNetV2 architecture from 116 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 
117 | 118 | Args: 119 | pretrained (bool): If True, returns a model pre-trained on ImageNet 120 | progress (bool): If True, displays a progress bar of the download to stderr 121 | """ 122 | model = MobileNetV2(**kwargs) 123 | if pretrained: 124 | script_dir = os.path.dirname(__file__) 125 | state_dict = torch.load(script_dir+'/state_dicts/mobilenet_v2.pt', map_location=device) 126 | model.load_state_dict(state_dict) 127 | return model 128 | -------------------------------------------------------------------------------- /cifar10_models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=dilation, groups=groups, bias=False, dilation=dilation) 12 | 13 | 14 | def conv1x1(in_planes, out_planes, stride=1): 15 | """1x1 convolution""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 23 | base_width=64, dilation=1, norm_layer=None): 24 | super(BasicBlock, self).__init__() 25 | if norm_layer is None: 26 | norm_layer = nn.BatchNorm2d 27 | if groups != 1 or base_width != 64: 28 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 29 | if dilation > 1: 30 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 31 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = norm_layer(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | 
self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = norm_layer(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 63 | base_width=64, dilation=1, norm_layer=None): 64 | super(Bottleneck, self).__init__() 65 | if norm_layer is None: 66 | norm_layer = nn.BatchNorm2d 67 | width = int(planes * (base_width / 64.)) * groups 68 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 69 | self.conv1 = conv1x1(inplanes, width) 70 | self.bn1 = norm_layer(width) 71 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 72 | self.bn2 = norm_layer(width) 73 | self.conv3 = conv1x1(width, planes * self.expansion) 74 | self.bn3 = norm_layer(planes * self.expansion) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.downsample = downsample 77 | self.stride = stride 78 | 79 | def forward(self, x): 80 | identity = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | identity = self.downsample(x) 95 | 96 | out += identity 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | 102 | class ResNet(nn.Module): 103 | 104 | def __init__(self, block, layers, num_classes=10, zero_init_residual=False, 105 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 106 | norm_layer=None): 107 | super(ResNet, self).__init__() 108 | if 
norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | self._norm_layer = norm_layer 111 | 112 | self.inplanes = 64 113 | self.dilation = 1 114 | if replace_stride_with_dilation is None: 115 | # each element in the tuple indicates if we should replace 116 | # the 2x2 stride with a dilated convolution instead 117 | replace_stride_with_dilation = [False, False, False] 118 | if len(replace_stride_with_dilation) != 3: 119 | raise ValueError("replace_stride_with_dilation should be None " 120 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 121 | self.groups = groups 122 | self.base_width = width_per_group 123 | 124 | ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1 125 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 126 | ## END 127 | 128 | self.bn1 = norm_layer(self.inplanes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 133 | dilate=replace_stride_with_dilation[0]) 134 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 135 | dilate=replace_stride_with_dilation[1]) 136 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 137 | dilate=replace_stride_with_dilation[2]) 138 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 139 | self.fc = nn.Linear(512 * block.expansion, num_classes) 140 | 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv2d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 144 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | 148 | # Zero-initialize the last BN in each residual branch, 149 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
150 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 151 | if zero_init_residual: 152 | for m in self.modules(): 153 | if isinstance(m, Bottleneck): 154 | nn.init.constant_(m.bn3.weight, 0) 155 | elif isinstance(m, BasicBlock): 156 | nn.init.constant_(m.bn2.weight, 0) 157 | 158 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 159 | norm_layer = self._norm_layer 160 | downsample = None 161 | previous_dilation = self.dilation 162 | if dilate: 163 | self.dilation *= stride 164 | stride = 1 165 | if stride != 1 or self.inplanes != planes * block.expansion: 166 | downsample = nn.Sequential( 167 | conv1x1(self.inplanes, planes * block.expansion, stride), 168 | norm_layer(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 173 | self.base_width, previous_dilation, norm_layer)) 174 | self.inplanes = planes * block.expansion 175 | for _ in range(1, blocks): 176 | layers.append(block(self.inplanes, planes, groups=self.groups, 177 | base_width=self.base_width, dilation=self.dilation, 178 | norm_layer=norm_layer)) 179 | 180 | return nn.Sequential(*layers) 181 | 182 | def forward(self, x): 183 | x = self.conv1(x) 184 | x = self.bn1(x) 185 | x = self.relu(x) 186 | x = self.maxpool(x) 187 | 188 | x = self.layer1(x) 189 | x = self.layer2(x) 190 | x = self.layer3(x) 191 | x = self.layer4(x) 192 | 193 | x = self.avgpool(x) 194 | x = x.reshape(x.size(0), -1) 195 | x = self.fc(x) 196 | 197 | return x 198 | 199 | 200 | 201 | def _resnet(arch, block, layers, pretrained, progress, device, **kwargs): 202 | model = ResNet(block, layers, **kwargs) 203 | if pretrained: 204 | script_dir = os.path.dirname(__file__) 205 | state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device) 206 | model.load_state_dict(state_dict) 207 | return model 208 | 209 | 210 | def resnet18(pretrained=False, progress=True, device='cpu', 
**kwargs): 211 | """Constructs a ResNet-18 model. 212 | 213 | Args: 214 | pretrained (bool): If True, returns a model pre-trained on ImageNet 215 | progress (bool): If True, displays a progress bar of the download to stderr 216 | """ 217 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device, 218 | **kwargs) 219 | 220 | 221 | def resnet34(pretrained=False, progress=True, device='cpu', **kwargs): 222 | """Constructs a ResNet-34 model. 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | progress (bool): If True, displays a progress bar of the download to stderr 227 | """ 228 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device, 229 | **kwargs) 230 | 231 | 232 | def resnet50(pretrained=False, progress=True, device='cpu', **kwargs): 233 | """Constructs a ResNet-50 model. 234 | 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | progress (bool): If True, displays a progress bar of the download to stderr 238 | """ 239 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device, 240 | **kwargs) 241 | 242 | 243 | def resnet101(pretrained=False, progress=True, device='cpu', **kwargs): 244 | """Constructs a ResNet-101 model. 245 | 246 | Args: 247 | pretrained (bool): If True, returns a model pre-trained on ImageNet 248 | progress (bool): If True, displays a progress bar of the download to stderr 249 | """ 250 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device, 251 | **kwargs) 252 | 253 | 254 | def resnet152(pretrained=False, progress=True, device='cpu', **kwargs): 255 | """Constructs a ResNet-152 model. 
256 | 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | progress (bool): If True, displays a progress bar of the download to stderr 260 | """ 261 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device, 262 | **kwargs) 263 | 264 | 265 | def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs): 266 | """Constructs a ResNeXt-50 32x4d model. 267 | 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | progress (bool): If True, displays a progress bar of the download to stderr 271 | """ 272 | kwargs['groups'] = 32 273 | kwargs['width_per_group'] = 4 274 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 275 | pretrained, progress, device, **kwargs) 276 | 277 | 278 | def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs): 279 | """Constructs a ResNeXt-101 32x8d model. 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | progress (bool): If True, displays a progress bar of the download to stderr 284 | """ 285 | kwargs['groups'] = 32 286 | kwargs['width_per_group'] = 8 287 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 288 | pretrained, progress, device, **kwargs) 289 | -------------------------------------------------------------------------------- /cifar10_models/resnet_orig.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | 6 | #Credit to https://github.com/akamaster/pytorch_resnet_cifar10 7 | 8 | __all__ = ['resnet_orig'] 9 | 10 | class LambdaLayer(nn.Module): 11 | def __init__(self, lambd): 12 | super(LambdaLayer, self).__init__() 13 | self.lambd = lambd 14 | 15 | def forward(self, x): 16 | return self.lambd(x) 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, in_planes, planes, 
stride=1, option='A'): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | 28 | self.shortcut = nn.Sequential() 29 | if stride != 1 or in_planes != planes: 30 | if option == 'A': 31 | """ 32 | For CIFAR10 ResNet paper uses option A. 33 | """ 34 | self.shortcut = LambdaLayer(lambda x: 35 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 36 | elif option == 'B': 37 | self.shortcut = nn.Sequential( 38 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 39 | nn.BatchNorm2d(self.expansion * planes) 40 | ) 41 | 42 | def forward(self, x): 43 | out = F.relu(self.bn1(self.conv1(x))) 44 | out = self.bn2(self.conv2(out)) 45 | out += self.shortcut(x) 46 | out = F.relu(out) 47 | return out 48 | 49 | class ResNet(nn.Module): 50 | def __init__(self, block, num_blocks, num_classes=10): 51 | super(ResNet, self).__init__() 52 | self.in_planes = 16 53 | 54 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(16) 56 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 57 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 58 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 59 | self.linear = nn.Linear(64, num_classes) 60 | 61 | def _make_layer(self, block, planes, num_blocks, stride): 62 | strides = [stride] + [1]*(num_blocks-1) 63 | layers = [] 64 | for stride in strides: 65 | layers.append(block(self.in_planes, planes, stride)) 66 | self.in_planes = planes * block.expansion 67 | 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = self.layer1(out) 73 | out = self.layer2(out) 74 | 
out = self.layer3(out) 75 | out = F.avg_pool2d(out, out.size()[3]) 76 | out = out.view(out.size(0), -1) 77 | out = self.linear(out) 78 | return out 79 | 80 | def resnet_orig(pretrained=True, device='cpu'): 81 | net = ResNet(BasicBlock, [3, 3, 3]) 82 | if pretrained: 83 | script_dir = os.path.dirname(__file__) 84 | state_dict = torch.load(script_dir + '/state_dicts/resnet_orig.pt', map_location=device) 85 | net.load_state_dict(state_dict) 86 | return net -------------------------------------------------------------------------------- /cifar10_models/vgg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | __all__ = [ 7 | "VGG", 8 | "vgg11_bn", 9 | "vgg13_bn", 10 | "vgg16_bn", 11 | "vgg19_bn", 12 | ] 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, features, num_classes=10, init_weights=True): 17 | super(VGG, self).__init__() 18 | self.features = features 19 | # CIFAR 10 (7, 7) to (1, 1) 20 | # self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 21 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 22 | 23 | self.classifier = nn.Sequential( 24 | nn.Linear(512 * 1 * 1, 4096), 25 | # nn.Linear(512 * 7 * 7, 4096), 26 | nn.ReLU(True), 27 | nn.Dropout(), 28 | nn.Linear(4096, 4096), 29 | nn.ReLU(True), 30 | nn.Dropout(), 31 | nn.Linear(4096, num_classes), 32 | ) 33 | if init_weights: 34 | self._initialize_weights() 35 | 36 | def forward(self, x): 37 | x = self.features(x) 38 | x = self.avgpool(x) 39 | x = x.view(x.size(0), -1) 40 | x = self.classifier(x) 41 | return x 42 | 43 | def _initialize_weights(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 47 | if m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.BatchNorm2d): 50 | nn.init.constant_(m.weight, 1) 51 | nn.init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.Linear): 53 | 
nn.init.normal_(m.weight, 0, 0.01) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | 57 | def make_layers(cfg, batch_norm=False): 58 | layers = [] 59 | in_channels = 3 60 | for v in cfg: 61 | if v == "M": 62 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 63 | else: 64 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 65 | if batch_norm: 66 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 67 | else: 68 | layers += [conv2d, nn.ReLU(inplace=True)] 69 | in_channels = v 70 | return nn.Sequential(*layers) 71 | 72 | 73 | cfgs = { 74 | "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 75 | "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], 76 | "D": [ 77 | 64, 78 | 64, 79 | "M", 80 | 128, 81 | 128, 82 | "M", 83 | 256, 84 | 256, 85 | 256, 86 | "M", 87 | 512, 88 | 512, 89 | 512, 90 | "M", 91 | 512, 92 | 512, 93 | 512, 94 | "M", 95 | ], 96 | "E": [ 97 | 64, 98 | 64, 99 | "M", 100 | 128, 101 | 128, 102 | "M", 103 | 256, 104 | 256, 105 | 256, 106 | 256, 107 | "M", 108 | 512, 109 | 512, 110 | 512, 111 | 512, 112 | "M", 113 | 512, 114 | 512, 115 | 512, 116 | 512, 117 | "M", 118 | ], 119 | } 120 | 121 | 122 | def _vgg(arch, cfg, batch_norm, pretrained, progress, device, **kwargs): 123 | if pretrained: 124 | kwargs["init_weights"] = False 125 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 126 | if pretrained: 127 | script_dir = os.path.dirname(__file__) 128 | state_dict = torch.load( 129 | script_dir + "/state_dicts/" + arch + ".pt", map_location=device 130 | ) 131 | model.load_state_dict(state_dict) 132 | return model 133 | 134 | 135 | def vgg11_bn(pretrained=False, progress=True, device="cpu", **kwargs): 136 | """VGG 11-layer model (configuration "A") with batch normalization 137 | 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | progress (bool): If True, displays a progress bar of the download to stderr 141 | """ 142 | return 
_vgg("vgg11_bn", "A", True, pretrained, progress, device, **kwargs) 143 | 144 | 145 | def vgg13_bn(pretrained=False, progress=True, device="cpu", **kwargs): 146 | """VGG 13-layer model (configuration "B") with batch normalization 147 | 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | progress (bool): If True, displays a progress bar of the download to stderr 151 | """ 152 | return _vgg("vgg13_bn", "B", True, pretrained, progress, device, **kwargs) 153 | 154 | 155 | def vgg16_bn(pretrained=False, progress=True, device="cpu", **kwargs): 156 | """VGG 16-layer model (configuration "D") with batch normalization 157 | 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | return _vgg("vgg16_bn", "D", True, pretrained, progress, device, **kwargs) 163 | 164 | 165 | def vgg19_bn(pretrained=False, progress=True, device="cpu", **kwargs): 166 | """VGG 19-layer model (configuration 'E') with batch normalization 167 | 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg("vgg19_bn", "E", True, pretrained, progress, device, **kwargs) -------------------------------------------------------------------------------- /config/chmnist_dense.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "chmnist", 3 | "model_name": "densenet121", 4 | "epochs": 100, 5 | "num_cls": 8, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/chmnist_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "chmnist", 3 | "model_name": "resnet18", 
4 | "epochs": 100, 5 | "num_cls": 8, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/chmnist_vgg16.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "chmnist", 3 | "model_name": "vgg16bn", 4 | "epochs": 100, 5 | "num_cls": 8, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar100_dense.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar100", 3 | "model_name": "densenet121", 4 | "epochs": 100, 5 | "num_cls": 100, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar100_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar100", 3 | "model_name": "resnet18", 4 | "epochs": 100, 5 | "num_cls": 100, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar100_vgg16.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar100", 3 | "model_name": "vgg16bn", 4 | "epochs": 100, 5 | "num_cls": 100, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar10_dense.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar10", 3 | "model_name": "densenet121", 4 | "epochs": 100, 5 | "num_cls": 10, 6 
| "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar10_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar10", 3 | "model_name": "resnet18", 4 | "epochs": 100, 5 | "num_cls": 10, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/cifar10_vgg16.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "cifar10", 3 | "model_name": "vgg16bn", 4 | "epochs": 100, 5 | "num_cls": 10, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } 11 | 12 | -------------------------------------------------------------------------------- /config/location.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "location", 3 | "model_name": "columnfc", 4 | "epochs": 50, 5 | "num_cls": 30, 6 | "input_dim": 446, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/purchase.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "purchase100", 3 | "model_name": "columnfc", 4 | "epochs": 50, 5 | "num_cls": 100, 6 | "input_dim": 600, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/svhn_dense.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "svhn", 3 | "model_name": "densenet121", 4 | "epochs": 100, 5 | "num_cls": 10, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | 
"lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/svhn_resnet18.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "svhn", 3 | "model_name": "resnet18", 4 | "epochs": 100, 5 | "num_cls": 10, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/svhn_vgg16.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "svhn", 3 | "model_name": "vgg16bn", 4 | "epochs": 100, 5 | "num_cls": 10, 6 | "input_dim": 3, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /config/texas100.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "texas100", 3 | "model_name": "columnfc", 4 | "epochs": 50, 5 | "num_cls": 100, 6 | "input_dim": 6169, 7 | "optimizer": "adam", 8 | "lr": 0.001, 9 | "weight_decay": 5e-4 10 | } -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sklearn 4 | import torch 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import os 8 | from torch.utils.data import TensorDataset 9 | 10 | 11 | def get_dataset(name, train=True): 12 | print(f"Build Dataset {name}") 13 | if name == "cifar10": 14 | mean = (0.4914, 0.4822, 0.4465) 15 | std = (0.2023, 0.1994, 0.2010) 16 | transform = transforms.Compose([ 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean, std), 19 | ]) 20 | dataset = torchvision.datasets.CIFAR10( 21 | root='/data/datasets/cifar10-data', train=train, 
download=True, transform=transform) 22 | elif name == "cifar100": 23 | mean = (0.4914, 0.4822, 0.4465) 24 | std = (0.2023, 0.1994, 0.2010) 25 | transform = transforms.Compose([ 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean, std), 28 | ]) 29 | dataset = torchvision.datasets.CIFAR100(root='/data/datasets/cifar100-data', train=train, download=True, 30 | transform=transform) 31 | elif name == "mnist": 32 | mean = (0.1307,) 33 | std = (0.3081,) 34 | transform = transforms.Compose([transforms.ToTensor(), 35 | transforms.Normalize(mean, std) 36 | ]) 37 | dataset = torchvision.datasets.MNIST(root='/data/datasets/mnist-data', train=train, download=True, 38 | transform=transform) 39 | 40 | elif name == "svhn": 41 | transform = transforms.Compose([ 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 44 | ]) 45 | dataset = torchvision.datasets.SVHN(root='/data/datasets/svhn-data', split='train' if train else "test", 46 | download=True, transform=transform) 47 | 48 | elif name == "texas100": 49 | # the dataset can be downloaded from https://www.comp.nus.edu.sg/~reza/files/dataset_texas.tgz 50 | dataset = np.load("/data/datasets/texas/data_complete.npz") 51 | x_data = torch.tensor(dataset['x'][:, :]).float() 52 | y_data = torch.tensor(dataset['y'][:] - 1).long() 53 | if train: 54 | dataset = TensorDataset(x_data, y_data) 55 | else: 56 | dataset = None 57 | 58 | elif name == "location": 59 | # the dataset can be downloaded from https://github.com/jjy1994/MemGuard/tree/master/data/location 60 | dataset = np.load("/data/datasets/location/data_complete.npz") 61 | x_data = torch.tensor(dataset['x'][:, :]).float() 62 | y_data = torch.tensor(dataset['y'][:] - 1).long() 63 | if train: 64 | dataset = TensorDataset(x_data, y_data) 65 | else: 66 | dataset = None 67 | 68 | elif name == "purchase100": 69 | # the dataset can be downloaded from https://www.comp.nus.edu.sg/~reza/files/dataset_purchase.tgz 70 | tensor_path = 
"/data/datasets/purchase100/purchase100.pt" 71 | if os.path.exists(tensor_path): 72 | data = torch.load(tensor_path) 73 | x_data, y_data = data['x'], data['y'] 74 | else: 75 | dataset = np.loadtxt("/data/datasets/purchase100/purchase100.txt", delimiter=',') 76 | x_data = torch.tensor(dataset[:, :-1]).float() 77 | y_data = torch.tensor(dataset[:, - 1]).long() 78 | torch.save({'x': x_data, 'y': y_data}, tensor_path) 79 | if train: 80 | dataset = TensorDataset(x_data, y_data) 81 | else: 82 | dataset = None 83 | 84 | elif name == "chmnist": 85 | # the dataset can be downloaded from https://zenodo.org/record/53169/files/Kather_texture_2016_image_tiles_5000.zip?download=1 86 | data_folder = "/data/datasets/chmnist/kather_texture_2016_image_tiles_5000/Kather_texture_2016_image_tiles_5000" 87 | mean = (0.485, 0.456, 0.406) 88 | std = (0.229, 0.224, 0.225) 89 | transform = transforms.Compose([ 90 | transforms.Resize((32, 32)), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean, std), 93 | ]) 94 | if train: 95 | dataset = torchvision.datasets.ImageFolder(data_folder, transform=transform) 96 | else: 97 | dataset = None 98 | else: 99 | raise ValueError 100 | 101 | return dataset 102 | 103 | -------------------------------------------------------------------------------- /figs/attack_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Machine-Learning-Security-Lab/mia_prune/db3a7af9fb36949ff135fb0faea8b57754d6e6af/figs/attack_pipeline.png -------------------------------------------------------------------------------- /mia.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import numpy as np 4 | import pickle 5 | import random 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | from torch.utils.data import ConcatDataset, DataLoader, Subset 9 | from base_model import BaseModel 10 | from datasets import 
get_dataset  # tail of mia.py line 10 ("from datasets import get_dataset"), split across lines by the dump
from attackers import MiaAttack

parser = argparse.ArgumentParser(description='Membership inference Attacks on Network Pruning')
parser.add_argument('device', default=0, type=int, help="GPU id to use")
parser.add_argument('config_path', default=0, type=str, help="config file path")
parser.add_argument('--dataset_name', default='mnist', type=str)
parser.add_argument('--model_name', default='mnist', type=str)
parser.add_argument('--num_cls', default=10, type=int)
parser.add_argument('--input_dim', default=1, type=int)
parser.add_argument('--image_size', default=28, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--seed', default=7, type=int)
parser.add_argument('--early_stop', default=5, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--prune_epochs', default=50, type=int)
parser.add_argument('--pruner_name', default='l1unstructure', type=str, help="prune method for victim model")
parser.add_argument('--prune_sparsity', default=0.7, type=float, help="prune sparsity for victim model")
parser.add_argument('--adaptive', action='store_true', help="use adaptive attack")
parser.add_argument('--shadow_num', default=5, type=int)
parser.add_argument('--defend', default='', type=str)
parser.add_argument('--defend_arg', default=4, type=float)
parser.add_argument('--attacks', default="samia", type=str)
parser.add_argument('--original', action='store_true', help="original=true, then launch attack against original model")


def _eval_loader(dataset, batch_size):
    """Return an unshuffled DataLoader used only to query models."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)


def _load_pruned(args, folder, device):
    """Build a BaseModel and load the pruned weights stored in ``{folder}/best.pth``."""
    pruned = BaseModel(
        args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, save_folder=folder, device=device)
    # map_location lets a checkpoint written on one GPU be loaded onto a different --device.
    pruned.model.load_state_dict(torch.load(f"{folder}/best.pth", map_location=device))
    return pruned


def main(args):
    """Load the victim/shadow models (original and pruned) and run the attacks named in ``--attacks``.

    ``--attacks`` is a comma-separated subset of: samia, threshold, nn, nn_top3, nn_cls.
    """
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    device = f"cuda:{args.device}"
    cudnn.benchmark = True

    # Folder suffixes must match the ones written by prune.py / prune_dp.py. The victim suffix
    # carries the defense whenever --defend is set; the shadow suffix only when --adaptive is set.
    base_prefix = f"{args.pruner_name}_{args.prune_sparsity}"
    defend_suffix = f"_{args.defend}_{args.defend_arg}"
    prune_prefix = base_prefix + (defend_suffix if args.defend else "")
    prune_prefix2 = base_prefix + (defend_suffix if args.adaptive else "")

    save_folder = f"results/{args.dataset_name}_{args.model_name}"
    print(f"Save Folder: {save_folder}")

    # Load datasets and the index split written by pretrain.py (a local, trusted pickle).
    trainset = get_dataset(args.dataset_name, train=True)
    testset = get_dataset(args.dataset_name, train=False)
    total_dataset = trainset if testset is None else ConcatDataset([trainset, testset])
    total_size = len(total_dataset)
    with open(f"{save_folder}/data_index.pkl", 'rb') as f:
        victim_train_list, _victim_dev_list, victim_test_list, attack_split_list = pickle.load(f)
    print(f"Total Data Size: {total_size}, "
          f"Victim Train Size: {len(victim_train_list)}, "
          f"Victim Test Size: {len(victim_test_list)}")
    victim_train_loader = _eval_loader(Subset(total_dataset, victim_train_list), args.batch_size)
    victim_test_loader = _eval_loader(Subset(total_dataset, victim_test_list), args.batch_size)

    # Load the original and the pruned victim model.
    victim_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
    victim_model.load(f"{save_folder}/victim_model/best.pth")

    pruned_model_save_folder = f"{save_folder}/{prune_prefix}_model"
    print(f"Load Pruned Model from {pruned_model_save_folder}")
    victim_pruned_model = _load_pruned(args, pruned_model_save_folder, device)
    victim_pruned_model.test(victim_train_loader, "Victim Pruned Model Train")
    victim_pruned_model.test(victim_test_loader, "Victim Pruned Model Test")

    # Load the original and the pruned shadow models.
    shadow_model_list, shadow_prune_model_list = [], []
    shadow_train_loader_list, shadow_test_loader_list = [], []
    for shadow_ind in range(args.shadow_num):
        attack_train_list, _attack_dev_list, attack_test_list = attack_split_list[shadow_ind]
        shadow_train_loader = _eval_loader(Subset(total_dataset, attack_train_list), args.batch_size)
        shadow_test_loader = _eval_loader(Subset(total_dataset, attack_test_list), args.batch_size)

        shadow_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
        shadow_model.load(f"{save_folder}/shadow_model_{shadow_ind}/best.pth")
        pruned_shadow_model_save_folder = f"{save_folder}/shadow_{prune_prefix2}_model_{shadow_ind}"
        print(f"Load Pruned Shadow Model From {pruned_shadow_model_save_folder}")
        shadow_pruned_model = _load_pruned(args, pruned_shadow_model_save_folder, device)
        shadow_pruned_model.test(shadow_train_loader, "Shadow Pruned Model Train")
        shadow_pruned_model.test(shadow_test_loader, "Shadow Pruned Model Test")

        shadow_model_list.append(shadow_model)
        shadow_prune_model_list.append(shadow_pruned_model)
        shadow_train_loader_list.append(shadow_train_loader)
        shadow_test_loader_list.append(shadow_test_loader)

    print("Start Membership Inference Attacks")
    attacker = MiaAttack(
        victim_model, victim_pruned_model, victim_train_loader, victim_test_loader,
        shadow_model_list, shadow_prune_model_list, shadow_train_loader_list, shadow_test_loader_list,
        num_cls=args.num_cls, device=device, batch_size=args.batch_size,
        attack_original=bool(args.original))

    # Tolerate "samia, threshold" as well as "samia,threshold".
    attacks = [name.strip() for name in args.attacks.split(',')]

    if "samia" in attacks:
        nn_trans_acc = attacker.nn_attack("nn_sens_cls", model_name="transformer")
        print(f"SAMIA attack accuracy {nn_trans_acc:.3f}")

    if "threshold" in attacks:
        conf, xent, mentr, top1_conf = attacker.threshold_attack()
        print(f"Ground-truth class confidence-based threshold attack (Conf) accuracy: {conf:.3f}")
        print(f"Cross-entropy-based threshold attack (Xent) accuracy: {xent:.3f}")
        print(f"Modified-entropy-based threshold attack (Mentr) accuracy: {mentr:.3f}")
        print(f"Top1 Confidence-based threshold attack (Top1-conf) accuracy: {top1_conf:.3f}")

    if "nn" in attacks:
        nn_acc = attacker.nn_attack("nn")
        print(f"NN attack accuracy {nn_acc:.3f}")

    if "nn_top3" in attacks:
        nn_top3_acc = attacker.nn_attack("nn_top3")
        print(f"Top3-NN Attack Accuracy {nn_top3_acc}")

    if "nn_cls" in attacks:
        nn_cls_acc = attacker.nn_attack("nn_cls")
        print(f"NNCls Attack Accuracy {nn_cls_acc}")


if __name__ == '__main__':
    args = parser.parse_args()
    # Values from the JSON config are pre-set on the namespace; argparse only fills in defaults for
    # attributes that are missing, and explicit CLI flags override both (CLI > JSON > defaults).
    with open(args.config_path) as f:
        t_args = argparse.Namespace()
        t_args.__dict__.update(json.load(f))
        args = parser.parse_args(namespace=t_args)

    print(args)
    main(args)

# --------------------------------------------------------------------------------
# /models.py
# --------------------------------------------------------------------------------
import torch.nn.functional as F
from torch import nn


class ColumnFC(nn.Module):
    """MLP input_dim -> 256 -> 128 -> output_dim with ReLU + dropout after each hidden layer.

    Attribute names (fc1, drop1, ...) determine the state_dict keys, so renaming them would
    break loading of previously saved weights.
    """

    def __init__(self, input_dim=100, output_dim=100, dropout=0.1):
        super(ColumnFC, self).__init__()
        self.fc1 = nn.Linear(input_dim, 256)
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        """Return unnormalized outputs of shape (..., output_dim)."""
        x = self.drop1(F.relu(self.fc1(x)))
        x = self.drop2(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x


class MIAFC(nn.Module):
    """Attack MLP input_dim -> 512 -> 256 -> 128 -> output_dim.

    ReLU follows every hidden layer, dropout follows the first two. The final layer has no
    activation, so forward() returns raw logits.
    """

    def __init__(self, input_dim=10, output_dim=1, dropout=0.2):
        super(MIAFC, self).__init__()
        self.fc1 = nn.Linear(input_dim, 512)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(512, 256)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, output_dim)

    def forward(self, x):
        """Return raw logits of shape (..., output_dim)."""
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x

# --------------------------------------------------------------------------------
# /pretrain.py
# --------------------------------------------------------------------------------
import argparse
import json
import numpy as np
import os
import pickle
import shutil
import random
import torch
import torch.backends.cudnn as cudnn
from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, DataLoader, Subset
from base_model import BaseModel
from datasets import get_dataset
from utils import seed_worker

parser = argparse.ArgumentParser()
parser.add_argument('device', default=0, type=int, help="GPU id to use")
parser.add_argument('config_path', default=0, type=str, help="config file path")
parser.add_argument('--dataset_name', default='mnist', type=str)
parser.add_argument('--model_name', default='mnist', type=str)
parser.add_argument('--num_cls', default=10, type=int)
parser.add_argument('--input_dim', default=1, type=int)
parser.add_argument('--image_size', default=28, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--seed', default=7, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--early_stop', default=5, type=int, help="patience for early stopping")
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--optimizer', default="adam", type=str)
parser.add_argument('--shadow_num', default=5, type=int)


def _make_loader(dataset, batch_size, shuffle):
    """Build a DataLoader with the worker seeding used throughout training."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4,
                      pin_memory=True, worker_init_fn=seed_worker)


def _make_split_loaders(total_dataset, train_list, dev_list, test_list, batch_size):
    """Return (train, dev, test) loaders over index subsets; only the train loader shuffles."""
    return (_make_loader(Subset(total_dataset, train_list), batch_size, shuffle=True),
            _make_loader(Subset(total_dataset, dev_list), batch_size, shuffle=False),
            _make_loader(Subset(total_dataset, test_list), batch_size, shuffle=False))


def _train_with_early_stop(model, train_loader, dev_loader, test_loader, epochs, early_stop, save_folder, tag=""):
    """Train ``model`` and copy its best-dev-accuracy checkpoint to ``{save_folder}/best.pth``.

    ``tag`` is inserted into the log prefixes (e.g. " Shadow"). Training stops once dev accuracy
    has not improved for more than ``early_stop`` epochs (disabled when ``early_stop <= 0``).

    Raises:
        RuntimeError: if no epoch ran (``epochs <= 0``), so there is no checkpoint to copy.
    """
    best_acc = -1  # below any reachable accuracy, so the first epoch always yields a checkpoint
    best_path = None  # reset per model so one model can never inherit another model's checkpoint
    count = 0
    for epoch in range(epochs):
        model.train(train_loader, f"Epoch {epoch}{tag} Train")
        dev_acc, _ = model.test(dev_loader, f"Epoch {epoch}{tag} Dev")
        test_acc, test_loss = model.test(test_loader, f"Epoch {epoch}{tag} Test")
        if dev_acc > best_acc:
            best_acc = dev_acc
            best_path = model.save(epoch, test_acc, test_loss)
            count = 0
        elif early_stop > 0:
            count += 1
            if count > early_stop:
                print(f"Early Stop at Epoch {epoch}")
                break
    if best_path is None:
        raise RuntimeError(f"No checkpoint saved for {save_folder}: epochs={epochs}")
    shutil.copyfile(best_path, f"{save_folder}/best.pth")


def main(args):
    """Split the data, then train the victim model and ``args.shadow_num`` shadow models."""
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = f"cuda:{args.device}"
    cudnn.benchmark = True
    save_folder = f"results/{args.dataset_name}_{args.model_name}"

    print(f"Save Folder: {save_folder}")
    trainset = get_dataset(args.dataset_name, train=True)
    testset = get_dataset(args.dataset_name, train=False)
    total_dataset = trainset if testset is None else ConcatDataset([trainset, testset])
    total_size = len(total_dataset)
    data_path = f"{save_folder}/data_index.pkl"

    # Split all indices 50/50 into a victim pool and an attacker (shadow) pool. Each pool is then split
    # into test (0.45 of the pool) and train, and 0.1818 of that train part (~10% of the pool) becomes dev.
    # makedirs also creates the parent "results/" folder when it does not exist yet.
    os.makedirs(save_folder, exist_ok=True)
    victim_list, attack_list = train_test_split(list(range(total_size)), test_size=0.5, random_state=args.seed)
    victim_train_list, victim_test_list = train_test_split(victim_list, test_size=0.45, random_state=args.seed)
    victim_train_list, victim_dev_list = train_test_split(
        victim_train_list, test_size=0.1818, random_state=args.seed)
    attack_split_list = []
    for i in range(args.shadow_num):
        attack_train_list, attack_test_list = train_test_split(
            attack_list, test_size=0.45, random_state=args.seed + i)
        attack_train_list, attack_dev_list = train_test_split(
            attack_train_list, test_size=0.1818, random_state=args.seed + i)
        attack_split_list.append([attack_train_list, attack_dev_list, attack_test_list])
    with open(data_path, 'wb') as f:
        pickle.dump([victim_train_list, victim_dev_list, victim_test_list, attack_split_list], f)

    # Train the victim model.
    print(f"Total Data Size: {total_size}, "
          f"Victim Train Size: {len(victim_train_list)}, "
          f"Victim Dev Size: {len(victim_dev_list)}, "
          f"Victim Test Size: {len(victim_test_list)}")
    victim_train_loader, victim_dev_loader, victim_test_loader = _make_split_loaders(
        total_dataset, victim_train_list, victim_dev_list, victim_test_list, args.batch_size)

    victim_model_save_folder = save_folder + "/victim_model"
    print("Train Victim Model")
    os.makedirs(victim_model_save_folder, exist_ok=True)
    victim_model = BaseModel(
        args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, save_folder=victim_model_save_folder,
        device=device, optimizer=args.optimizer, lr=args.lr, weight_decay=args.weight_decay)
    _train_with_early_stop(victim_model, victim_train_loader, victim_dev_loader, victim_test_loader,
                           args.epochs, args.early_stop, victim_model_save_folder)

    # Train shadow models, each on its own split of the attacker pool.
    for shadow_ind in range(args.shadow_num):
        attack_train_list, attack_dev_list, attack_test_list = attack_split_list[shadow_ind]
        attack_train_loader, attack_dev_loader, attack_test_loader = _make_split_loaders(
            total_dataset, attack_train_list, attack_dev_list, attack_test_list, args.batch_size)

        print(f"Train Shadow Model {shadow_ind}")
        shadow_model_save_folder = f"{save_folder}/shadow_model_{shadow_ind}"
        os.makedirs(shadow_model_save_folder, exist_ok=True)
        shadow_model = BaseModel(
            args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, save_folder=shadow_model_save_folder,
            device=device, optimizer=args.optimizer, lr=args.lr, weight_decay=args.weight_decay)
        _train_with_early_stop(shadow_model, attack_train_loader, attack_dev_loader, attack_test_loader,
                               args.epochs, args.early_stop, shadow_model_save_folder, tag=" Shadow")


if __name__ == '__main__':
    args = parser.parse_args()
    # JSON config values pre-populate the namespace; explicit CLI flags override them.
    with open(args.config_path) as f:
        t_args = argparse.Namespace()
        t_args.__dict__.update(json.load(f))
        args = parser.parse_args(namespace=t_args)
    args.prune_epochs = int(args.epochs) // 2

    print(args)
    main(args)

# --------------------------------------------------------------------------------
# /prune.py
# --------------------------------------------------------------------------------
import argparse
import copy
import json
import numpy as np
import os
import pickle
import random
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import ConcatDataset, DataLoader, Subset
from base_model import BaseModel
from datasets import get_dataset
from pruner import get_pruner
from utils import seed_worker
from pyvacy import optim, analysis

parser = argparse.ArgumentParser()
parser.add_argument('device', default=0, type=int, help="GPU id to use")
parser.add_argument('config_path', default=0, type=str, help="config file path")
parser.add_argument('--dataset_name', default='mnist', type=str)
parser.add_argument('--model_name', default='mnist', type=str)
parser.add_argument('--num_cls', default=10, type=int)
parser.add_argument('--input_dim', default=1, type=int)
parser.add_argument('--image_size', default=28, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--seed', default=7, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--early_stop', default=5, type=int, help="patience for early stopping")
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--optimizer', default="adam", type=str)
parser.add_argument('--prune_epochs', default=50, type=int)
parser.add_argument('--pruner_name', default='l1unstructure', type=str)
parser.add_argument('--prune_sparsity', default=0.7, type=float)
parser.add_argument('--defend', default="", type=str, help="'' if no defense, else ppb")
parser.add_argument('--adaptive', action='store_true')
parser.add_argument('--shadow_num', default=5, type=int)
parser.add_argument('--defend_arg', default=4, type=float)

# Values accepted for --defend by this script.
_SUPPORTED_DEFENSES = ("", "ppb", "adv")


def _make_loader(dataset, batch_size, shuffle):
    """Build a DataLoader with the worker seeding used throughout training."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4,
                      pin_memory=True, worker_init_fn=seed_worker)


def _make_split_loaders(total_dataset, train_list, dev_list, test_list, batch_size):
    """Return (train, dev, test) loaders over index subsets; only the train loader shuffles."""
    return (_make_loader(Subset(total_dataset, train_list), batch_size, shuffle=True),
            _make_loader(Subset(total_dataset, dev_list), batch_size, shuffle=False),
            _make_loader(Subset(total_dataset, test_list), batch_size, shuffle=False))


def _build_pruned_model(args, state_dict, save_folder, device, attack_model_type):
    """Create ``save_folder``, load ``state_dict`` into a fresh BaseModel and wrap it with the pruner.

    Returns:
        (pruned BaseModel, pruner) — ``pruned.model`` is the module returned by ``pruner.compress()``.
    """
    os.makedirs(save_folder, exist_ok=True)
    pruned = BaseModel(
        args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, lr=args.lr,
        weight_decay=args.weight_decay, save_folder=save_folder, device=device,
        optimizer=args.optimizer, attack_model_type=attack_model_type)
    pruned.model.load_state_dict(state_dict)
    pruner = get_pruner(args.pruner_name, pruned.model, sparsity=args.prune_sparsity)
    pruned.model = pruner.compress()
    return pruned, pruner


def _finetune_pruned(pruned_model, pruner, args, train_loader, dev_loader, test_loader, save_folder, role):
    """Fine-tune a pruned model (optionally with the PPB/ADV defense) with early stopping on dev accuracy.

    The best-dev weights and mask are exported to ``save_folder``. ``role`` is "Victim" or "Shadow"
    and only affects log prefixes.

    Returns:
        Test accuracy of the exported (best-dev) checkpoint.

    Raises:
        RuntimeError: if no epoch ran (``args.prune_epochs <= 0``).
    """
    # Historic log format: victim plain/dev/test lines carry no role word, shadow lines do.
    tag = "" if role == "Victim" else f" {role}"
    best_acc = -1  # below any reachable accuracy, so the first epoch always exports a checkpoint
    best_test_acc = None
    count = 0
    for epoch in range(args.prune_epochs):
        pruner.update_epoch(epoch)
        if args.defend == "":
            pruned_model.train(train_loader, f"Epoch {epoch}{tag} Prune Train")
        elif args.defend == "ppb":
            pruned_model.train_defend_ppb(
                train_loader, log_pref=f"Epoch {epoch} {role} Prune Train With PPB", defend_arg=args.defend_arg)
        elif args.defend == "adv":
            pruned_model.train_defend_adv(
                train_loader, dev_loader, log_pref=f"Epoch {epoch} {role} Prune Train With ADV",
                privacy_theta=args.defend_arg)
        dev_acc, _ = pruned_model.test(dev_loader, f"Epoch {epoch}{tag} Prune Dev")
        test_acc, _ = pruned_model.test(test_loader, f"Epoch {epoch}{tag} Prune Test")

        if dev_acc > best_acc:
            best_acc = dev_acc
            best_test_acc = test_acc
            pruner.export_model(model_path=f"{save_folder}/best.pth",
                                mask_path=f"{save_folder}/best_mask.pth")
            count = 0
        elif args.early_stop > 0:
            count += 1
            if count > args.early_stop:
                print(f"Early Stop at Epoch {epoch}")
                break
    if best_test_acc is None:
        raise RuntimeError(f"No pruned checkpoint exported to {save_folder}: prune_epochs={args.prune_epochs}")
    return best_test_acc


def main(args):
    """Prune and fine-tune the pretrained victim and shadow models.

    Returns:
        (victim accuracy, pruned-victim accuracy, mean shadow accuracy, mean pruned-shadow accuracy),
        where the pruned accuracies are those of the exported best-dev checkpoints.

    Raises:
        ValueError: if ``args.defend`` is not one of "", "ppb", "adv".
    """
    if args.defend not in _SUPPORTED_DEFENSES:
        raise ValueError(f"Unknown --defend {args.defend!r}; expected one of {_SUPPORTED_DEFENSES}")
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = f"cuda:{args.device}"
    cudnn.benchmark = True
    if args.defend == "":
        prune_prefix = f"{args.pruner_name}_{args.prune_sparsity}"
    else:
        prune_prefix = f"{args.pruner_name}_{args.prune_sparsity}_{args.defend}_{args.defend_arg}"
    # The ADV defense needs an auxiliary attack model inside BaseModel.
    attack_model_type = "mia_fc" if args.defend == "adv" else ""

    save_folder = f"results/{args.dataset_name}_{args.model_name}"

    print(f"Save Folder: {save_folder}")
    trainset = get_dataset(args.dataset_name, train=True)
    testset = get_dataset(args.dataset_name, train=False)
    total_dataset = trainset if testset is None else ConcatDataset([trainset, testset])
    total_size = len(total_dataset)

    # Load the data split used by the pretrained victim and shadow models (local, trusted pickle).
    with open(f"{save_folder}/data_index.pkl", 'rb') as f:
        victim_train_list, victim_dev_list, victim_test_list, attack_split_list = pickle.load(f)

    print(f"Total Data Size: {total_size}, "
          f"Victim Train Size: {len(victim_train_list)}, "
          f"Victim Dev Size: {len(victim_dev_list)}, "
          f"Victim Test Size: {len(victim_test_list)}")
    victim_train_loader, victim_dev_loader, victim_test_loader = _make_split_loaders(
        total_dataset, victim_train_list, victim_dev_list, victim_test_list, args.batch_size)

    # Load and evaluate the pretrained victim model.
    victim_model_path = f"{save_folder}/victim_model/best.pth"
    victim_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
    victim_model.load(victim_model_path)
    victim_acc, _ = victim_model.test(victim_test_loader, "Pretrained Victim")

    print("Prune Victim Model")
    pruned_model_save_folder = f"{save_folder}/{prune_prefix}_model"
    # Reload so pruning starts exactly from the saved checkpoint, whatever test() may have touched.
    victim_model.load(victim_model_path)
    victim_pruned_model, pruner = _build_pruned_model(
        args, copy.deepcopy(victim_model.model.state_dict()), pruned_model_save_folder, device, attack_model_type)
    victim_prune_acc = _finetune_pruned(
        victim_pruned_model, pruner, args, victim_train_loader, victim_dev_loader, victim_test_loader,
        pruned_model_save_folder, "Victim")

    # Prune shadow models.
    shadow_acc_list = []
    shadow_prune_acc_list = []
    for shadow_ind in range(args.shadow_num):
        attack_train_list, attack_dev_list, attack_test_list = attack_split_list[shadow_ind]
        attack_train_loader, attack_dev_loader, attack_test_loader = _make_split_loaders(
            total_dataset, attack_train_list, attack_dev_list, attack_test_list, args.batch_size)

        shadow_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
        shadow_model.load(f"{save_folder}/shadow_model_{shadow_ind}/best.pth")
        shadow_acc, _ = shadow_model.test(attack_test_loader, "Pretrain Shadow")

        pruned_shadow_model_save_folder = f"{save_folder}/shadow_{prune_prefix}_model_{shadow_ind}"
        shadow_pruned_model, pruner = _build_pruned_model(
            args, copy.deepcopy(shadow_model.model.state_dict()), pruned_shadow_model_save_folder, device,
            attack_model_type)
        shadow_prune_acc = _finetune_pruned(
            shadow_pruned_model, pruner, args, attack_train_loader, attack_dev_loader, attack_test_loader,
            pruned_shadow_model_save_folder, "Shadow")

        shadow_acc_list.append(shadow_acc)
        shadow_prune_acc_list.append(shadow_prune_acc)
    return victim_acc, victim_prune_acc, np.mean(shadow_acc_list), np.mean(shadow_prune_acc_list)


if __name__ == '__main__':
    args = parser.parse_args()
    # JSON config values pre-populate the namespace; explicit CLI flags override them.
    with open(args.config_path) as f:
        t_args = argparse.Namespace()
        t_args.__dict__.update(json.load(f))
        args = parser.parse_args(namespace=t_args)
    args.prune_epochs = int(args.epochs) // 2

    print(args)
    main(args)

# --------------------------------------------------------------------------------
# /prune_dp.py
# --------------------------------------------------------------------------------
import argparse
import copy
import json
import numpy as np
import os
import pickle
import random
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import ConcatDataset, DataLoader, Subset, TensorDataset
from base_model import BaseModel
from datasets import get_dataset
from pruner import get_pruner
from utils import seed_worker
from pyvacy import optim, analysis, sampling

parser = argparse.ArgumentParser()
parser.add_argument('device', default=0, type=int, help="GPU id to use")
parser.add_argument('config_path', default=0, type=str, help="config file path")
parser.add_argument('--dataset_name', default='mnist', type=str)
parser.add_argument('--model_name', default='mnist', type=str)
parser.add_argument('--num_cls', default=10, type=int)
parser.add_argument('--input_dim', default=1, type=int)
parser.add_argument('--image_size', default=28, type=int)
parser.add_argument('--hidden_size', default=128, type=int)
parser.add_argument('--seed', default=7, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--early_stop', default=5, type=int, help="patience for early stopping")
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--weight_decay', default=5e-4, type=float)
parser.add_argument('--optimizer', default="adam", type=str)
parser.add_argument('--prune_epochs', default=50, type=int)
parser.add_argument('--pruner_name', default='l1unstructure', type=str)
parser.add_argument('--prune_sparsity', default=0.7, type=float)
parser.add_argument('--defend', default="dp", type=str, help="DPSGD algorithm")
parser.add_argument('--adaptive', action='store_true')
parser.add_argument('--shadow_num', default=5, type=int)
parser.add_argument('--defend_arg', default=0.1, type=float)


def _make_loader(dataset, batch_size):
    """Build an unshuffled evaluation DataLoader (DP training samples its own minibatches)."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4,
                      pin_memory=True, worker_init_fn=seed_worker)


def _build_pruned_model(args, state_dict, save_folder, device):
    """Create ``save_folder``, load ``state_dict`` into a fresh BaseModel and wrap it with the pruner."""
    os.makedirs(save_folder, exist_ok=True)
    pruned = BaseModel(
        args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, lr=args.lr,
        weight_decay=args.weight_decay, save_folder=save_folder, device=device, optimizer=args.optimizer)
    pruned.model.load_state_dict(state_dict)
    pruner = get_pruner(args.pruner_name, pruned.model, sparsity=args.prune_sparsity)
    pruned.model = pruner.compress()
    return pruned, pruner


def _dp_finetune_pruned(pruned_model, pruner, dp_optimizer, train_dataset, dev_loader, test_loader,
                        save_folder, minibatch_size, microbatch_size, args, device):
    """Fine-tune a pruned model with pyvacy DP-SGD, with early stopping on dev accuracy.

    The best-dev weights and mask are exported to ``save_folder``.

    Returns:
        Test accuracy of the exported (best-dev) checkpoint.

    Raises:
        RuntimeError: if no epoch ran (``args.prune_epochs <= 0``).
    """
    # pyvacy's minibatch_loader yields ``iterations`` sampled minibatches per call, and it is called
    # once per epoch below, so ``iterations`` must be the per-epoch step count (previously it was
    # multiplied by args.epochs, making every "epoch" run args.epochs epochs of DP-SGD).
    # For privacy accounting, the total number of steps is iterations * epochs actually run
    # (see analysis.epsilon).
    iterations = max(len(train_dataset) // minibatch_size, 1)
    minibatch_loader, microbatch_loader = sampling.get_data_loaders(minibatch_size, microbatch_size, iterations)

    best_acc = -1  # below any reachable accuracy, so the first epoch always exports a checkpoint
    best_test_acc = None
    count = 0
    for epoch in range(args.prune_epochs):
        pruner.update_epoch(epoch)
        # test() below may switch the model to eval mode; re-enter training mode every epoch.
        pruned_model.model.train()
        total_loss = 0
        total = 0
        for X_minibatch, y_minibatch in minibatch_loader(train_dataset):
            dp_optimizer.zero_grad()
            for X_microbatch, y_microbatch in microbatch_loader(TensorDataset(X_minibatch, y_minibatch)):
                X_microbatch, y_microbatch = X_microbatch.to(device), y_microbatch.to(device)
                dp_optimizer.zero_microbatch_grad()
                loss = pruned_model.criterion(pruned_model.model(X_microbatch), y_microbatch)
                loss.backward()
                dp_optimizer.microbatch_step()
                size = X_microbatch.size(0)
                total_loss += loss.item() * size
                total += size
            dp_optimizer.step()
        # max() guards against an epoch in which the sampler produced no full microbatch.
        print(f"Epoch {epoch} Prune Train: Loss {total_loss / max(total, 1)}")
        dev_acc, _ = pruned_model.test(dev_loader, f"Epoch {epoch} Prune Dev")
        test_acc, _ = pruned_model.test(test_loader, f"Epoch {epoch} Prune Test")

        if dev_acc > best_acc:
            best_acc = dev_acc
            best_test_acc = test_acc
            pruner.export_model(model_path=f"{save_folder}/best.pth",
                                mask_path=f"{save_folder}/best_mask.pth")
            count = 0
        elif args.early_stop > 0:
            count += 1
            if count > args.early_stop:
                print(f"Early Stop at Epoch {epoch}")
                break
    if best_test_acc is None:
        raise RuntimeError(f"No pruned checkpoint exported to {save_folder}: prune_epochs={args.prune_epochs}")
    return best_test_acc


def main(args):
    """Prune the pretrained victim and shadow models and fine-tune them with DP-SGD.

    Returns:
        (victim accuracy, pruned-victim accuracy, mean shadow accuracy, mean pruned-shadow accuracy),
        where the pruned accuracies are those of the exported best-dev checkpoints.
    """
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = f"cuda:{args.device}"
    cudnn.benchmark = True

    minibatch_size = args.batch_size
    microbatch_size = max(args.batch_size // 2, 1)  # a batch_size of 1 would otherwise give 0
    # Keyword arguments forwarded to pyvacy's DPSGD; --defend_arg is the noise multiplier.
    dp_training_parameters = {
        'minibatch_size': minibatch_size, 'l2_norm_clip': 1.0, 'noise_multiplier': args.defend_arg,
        'microbatch_size': microbatch_size, 'lr': args.lr, 'weight_decay': args.weight_decay}

    if args.defend == "":
        prune_prefix = f"{args.pruner_name}_{args.prune_sparsity}"
    else:
        prune_prefix = f"{args.pruner_name}_{args.prune_sparsity}_{args.defend}_{args.defend_arg}"

    save_folder = f"results/{args.dataset_name}_{args.model_name}"

    print(f"Save Folder: {save_folder}")
    trainset = get_dataset(args.dataset_name, train=True)
    testset = get_dataset(args.dataset_name, train=False)
    total_dataset = trainset if testset is None else ConcatDataset([trainset, testset])
    total_size = len(total_dataset)

    # Load the data split used by the pretrained victim and shadow models (local, trusted pickle).
    with open(f"{save_folder}/data_index.pkl", 'rb') as f:
        victim_train_list, victim_dev_list, victim_test_list, attack_split_list = pickle.load(f)

    victim_train_dataset = Subset(total_dataset, victim_train_list)
    print(f"Total Data Size: {total_size}, "
          f"Victim Train Size: {len(victim_train_list)}, "
          f"Victim Dev Size: {len(victim_dev_list)}, "
          f"Victim Test Size: {len(victim_test_list)}")
    victim_dev_loader = _make_loader(Subset(total_dataset, victim_dev_list), args.batch_size)
    victim_test_loader = _make_loader(Subset(total_dataset, victim_test_list), args.batch_size)

    # Load and evaluate the pretrained victim model.
    victim_model_path = f"{save_folder}/victim_model/best.pth"
    victim_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
    victim_model.load(victim_model_path)
    victim_acc, _ = victim_model.test(victim_test_loader, "Pretrained Victim")

    print("Prune Victim Model")
    pruned_victim_model_save_folder = f"{save_folder}/{prune_prefix}_model"
    # Reload so pruning starts exactly from the saved checkpoint, whatever test() may have touched.
    victim_model.load(victim_model_path)
    victim_pruned_model, pruner = _build_pruned_model(
        args, copy.deepcopy(victim_model.model.state_dict()), pruned_victim_model_save_folder, device)
    victim_optimizer = optim.DPSGD(params=victim_pruned_model.model.parameters(), **dp_training_parameters)
    victim_prune_acc = _dp_finetune_pruned(
        victim_pruned_model, pruner, victim_optimizer, victim_train_dataset, victim_dev_loader,
        victim_test_loader, pruned_victim_model_save_folder, minibatch_size, microbatch_size, args, device)

    # Prune shadow models.
    shadow_acc_list = []
    shadow_prune_acc_list = []
    for shadow_ind in range(args.shadow_num):
        attack_train_list, attack_dev_list, attack_test_list = attack_split_list[shadow_ind]
        attack_train_dataset = Subset(total_dataset, attack_train_list)
        attack_dev_loader = _make_loader(Subset(total_dataset, attack_dev_list), args.batch_size)
        attack_test_loader = _make_loader(Subset(total_dataset, attack_test_list), args.batch_size)

        shadow_model = BaseModel(args.model_name, num_cls=args.num_cls, input_dim=args.input_dim, device=device)
        shadow_model.load(f"{save_folder}/shadow_model_{shadow_ind}/best.pth")
        shadow_acc, _ = shadow_model.test(attack_test_loader, "Pretrain Shadow")

        pruned_shadow_model_save_folder = f"{save_folder}/shadow_{prune_prefix}_model_{shadow_ind}"
        shadow_pruned_model, pruner = _build_pruned_model(
            args, copy.deepcopy(shadow_model.model.state_dict()), pruned_shadow_model_save_folder, device)
        shadow_optimizer = optim.DPSGD(params=shadow_pruned_model.model.parameters(), **dp_training_parameters)
        shadow_prune_acc = _dp_finetune_pruned(
            shadow_pruned_model, pruner, shadow_optimizer, attack_train_dataset, attack_dev_loader,
            attack_test_loader, pruned_shadow_model_save_folder, minibatch_size, microbatch_size, args, device)

        shadow_acc_list.append(shadow_acc)
        shadow_prune_acc_list.append(shadow_prune_acc)
    return victim_acc, victim_prune_acc, np.mean(shadow_acc_list), np.mean(shadow_prune_acc_list)


if __name__ == '__main__':
    args = parser.parse_args()
    # JSON config values pre-populate the namespace; explicit CLI flags override them.
    with open(args.config_path) as f:
        t_args = argparse.Namespace()
        t_args.__dict__.update(json.load(f))
        args = parser.parse_args(namespace=t_args)
    args.prune_epochs = int(args.epochs) // 2

    print(args)
    main(args)

# --------------------------------------------------------------------------------
# /pruner.py
# --------------------------------------------------------------------------------
from nni.algorithms.compression.pytorch.pruning import LevelPruner, SlimPruner, L2FilterPruner, \
    L1FilterPruner

# pruner name -> (NNI pruner class, op_types the sparsity config applies to)
_PRUNERS = {
    "l1unstructure": (LevelPruner, ["default"]),
    "slim": (SlimPruner, ["BatchNorm2d"]),
    "l1structure": (L1FilterPruner, ["Conv2d"]),
    "l2structure": (L2FilterPruner, ["Conv2d"]),
}


def get_pruner(pruner_name, model, sparsity=0.5):
    """Return an NNI pruner for ``model`` with a single config entry of the given ``sparsity``.

    Raises:
        ValueError: if ``pruner_name`` is not one of the keys of ``_PRUNERS``.
    """
    try:
        pruner_cls, op_types = _PRUNERS[pruner_name]
    except KeyError:
        raise ValueError(f"Unknown pruner {pruner_name!r}; expected one of {sorted(_PRUNERS)}") from None
    config_list = [{
        'sparsity': sparsity,
        'op_types': list(op_types),  # fresh list so the shared table is never mutated by NNI
    }]
    return pruner_cls(model, config_list)
# ==== /pyvacy/__init__.py ====
"""
original code is from https://github.com/ChrisWaites/pyvacy
"""

# ==== /pyvacy/analysis/__init__.py ====
# Package re-export; the epsilon_calculation definitions follow inline in this dump.
# from .epsilon_calculation import *

# ==== /pyvacy/analysis/epsilon_calculation.py ====
import math
from .rdp_accountant import compute_rdp, get_privacy_spent


def epsilon(N, batch_size, noise_multiplier, iterations, delta=1e-5):
    """Calculate epsilon for DP-SGD run as a sampled Gaussian mechanism.

    The sampling rate is ``batch_size / N``. The RDP order is chosen by a
    72-iteration ternary search over [1, 512]. This assumes epsilon is
    unimodal in the order; the search finds a local minimum otherwise.

    Args:
        N (int): Total number of training examples.
        batch_size (int): (Expected) minibatch size.
        noise_multiplier (float): Noise multiplier for DP-SGD.
        iterations (int): Number of optimizer steps taken.
        delta (float): Target delta.

    Returns:
        float: epsilon

    Example::
        >>> epsilon(10000, 256, 1.1, 1000, 1e-5)  # doctest: +SKIP
    """
    q = batch_size / N

    def eps_at(order):
        return _apply_dp_sgd_analysis(q, noise_multiplier, iterations, [order], delta)

    optimal_order = _ternary_search(eps_at, 1, 512, 72)
    return eps_at(optimal_order)


def _apply_dp_sgd_analysis(q, sigma, iterations, orders, delta):
    """Calculate epsilon for the given RDP orders.

    Args:
        q (float): Sampling probability, generally batch_size / number_of_samples
        sigma (float): Noise multiplier
        iterations (float): Number of iterations mechanism is applied
        orders (list(float)): Orders to try for finding optimal epsilon
        delta (float): Target delta

    Returns:
        float: epsilon (the smallest over ``orders``)

    Example::
        >>> _apply_dp_sgd_analysis(256 / 10000, 1.1, 1000, [8.0], 1e-5)  # doctest: +SKIP
    """
    rdp = compute_rdp(q, sigma, iterations, orders)
    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
    return eps


def _ternary_search(f, left, right, iterations):
    """Search the closed domain [left, right] for the value that minimizes f.

    Runs a fixed number of ternary-search ``iterations`` and returns the
    midpoint of the final interval.
    """
    for _ in range(iterations):
        left_third = left + (right - left) / 3
        right_third = right - (right - left) / 3
        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third
    return (left + right) / 2


# ==== /pyvacy/analysis/moments_accountant.py ====
import math
from pyvacy.analysis.rdp_accountant import compute_rdp, get_privacy_spent


def epsilon(N, batch_size, noise_multiplier, epochs, delta=1e-5):
    """Calculate epsilon for DP-SGD run for a (possibly fractional) number of epochs.

    The number of steps is ``ceil(epochs * N / batch_size)``. The RDP order
    is chosen by ternary search over [1, 512] down to an interval width of 0.1.

    Args:
        N (int): Total numbers of examples
        batch_size (int): Batch size
        noise_multiplier (float): Noise multiplier for DP-SGD
        epochs (float): number of epochs (may be fractional)
        delta (float): Target delta

    Returns:
        float: epsilon

    Example::
        >>> epsilon(10000, 256, 1.1, 10, 1e-5)  # doctest: +SKIP
    """
    q = batch_size / N
    steps = int(math.ceil(epochs * N / batch_size))

    def eps_at(order):
        return _apply_dp_sgd_analysis(q, noise_multiplier, steps, [order], delta)

    optimal_order = _ternary_search(eps_at, 1, 512, 0.1)
    return eps_at(optimal_order)


def _apply_dp_sgd_analysis(q, sigma, steps, orders, delta):
    """Calculate epsilon for the given RDP orders.

    Args:
        q (float): Sampling probability, generally batch_size / number_of_samples
        sigma (float): Noise multiplier
        steps (float): Number of steps mechanism is applied
        orders (list(float)): Orders to try for finding optimal epsilon
        delta (float): Target delta

    Returns:
        float: epsilon (the smallest over ``orders``)

    Example::
        >>> _apply_dp_sgd_analysis(256 / 10000, 1.1, 400, [8.0], 1e-5)  # doctest: +SKIP
    """
    rdp = compute_rdp(q, sigma, steps, orders)
    eps, _, _ = get_privacy_spent(orders, rdp, target_delta=delta)
    return eps


def _ternary_search(f, left, right, precision):
    """Search the closed domain [left, right] for the value that minimizes f.

    Stops once the interval is narrower than ``precision`` and returns its
    midpoint.

    Raises:
        ValueError: If ``precision`` is not positive (the loop could never end).
    """
    if precision <= 0:
        raise ValueError("precision must be positive")
    while True:
        if abs(right - left) < precision:
            return (left + right) / 2

        left_third = left + (right - left) / 3
        right_third = right - (right - left) / 3

        if f(left_third) < f(right_third):
            right = right_third
        else:
            left = left_third


# ==== /pyvacy/analysis/rdp_accountant.py ====
# Copyright 2018 The
# TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Modified: Python-3 only (dropped the `six` dependency and no-op `__future__`
# imports); `_log_erfc` fallback now triggers on AttributeError.
"""RDP analysis of the Sampled Gaussian Mechanism.

Functionality for computing Renyi differential privacy (RDP) of an additive
Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods:
  compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated
                                   T times.
  get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta
                                   (or eps) given RDP at multiple orders and
                                   a target value for eps (or delta).

Example use:

Suppose that we have run an SGM applied to a function with l2-sensitivity 1.
Its parameters are given as a list of tuples (q1, sigma1, T1), ...,
(qk, sigma_k, Tk), and we wish to compute eps for a given delta.
The example code would be:

  max_order = 32
  orders = range(2, max_order + 1)
  rdp = np.zeros_like(orders, dtype=float)
  for q, sigma, T in parameters:
    rdp += rdp_accountant.compute_rdp(q, sigma, T, orders)
  eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=delta)
"""
import math
import sys

import numpy as np
from scipy import special

########################
# LOG-SPACE ARITHMETIC #
########################


def _log_add(logx, logy):
    """Add two numbers in the log space."""
    a, b = min(logx, logy), max(logx, logy)
    if a == -np.inf:  # adding 0
        return b
    # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
    return math.log1p(math.exp(a - b)) + b  # log1p(x) = log(x + 1)


def _log_sub(logx, logy):
    """Subtract two numbers in the log space. Answer must be non-negative."""
    if logx < logy:
        raise ValueError("The result of subtraction must be non-negative.")
    if logy == -np.inf:  # subtracting 0
        return logx
    if logx == logy:
        return -np.inf  # 0 is represented as -np.inf in the log space.

    try:
        # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
        return math.log(math.expm1(logx - logy)) + logy  # expm1(x) = exp(x) - 1
    except OverflowError:
        return logx


def _log_print(logx):
    """Pretty print."""
    if logx < math.log(sys.float_info.max):
        return "{}".format(math.exp(logx))
    else:
        return "exp({})".format(logx)


def _compute_log_a_int(q, sigma, alpha):
    """Compute log(A_alpha) for integer alpha. 0 < q < 1."""
    # Python 3: ``int`` is the only integer type (replaces six.integer_types).
    assert isinstance(alpha, int)

    # Initialize with 0 in the log space.
    log_a = -np.inf

    for i in range(alpha + 1):
        log_coef_i = (
            math.log(special.binom(alpha, i)) + i * math.log(q) +
            (alpha - i) * math.log(1 - q))

        s = log_coef_i + (i * i - i) / (2 * (sigma**2))
        log_a = _log_add(log_a, s)

    return float(log_a)


def _compute_log_a_frac(q, sigma, alpha):
    """Compute log(A_alpha) for fractional alpha. 0 < q < 1."""
    # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
    # initialized to 0 in the log space:
    log_a0, log_a1 = -np.inf, -np.inf
    i = 0

    z0 = sigma**2 * math.log(1 / q - 1) + .5

    while True:  # do ... until loop
        coef = special.binom(alpha, i)
        log_coef = math.log(abs(coef))
        j = alpha - i

        log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
        log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)

        log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
        log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))

        log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0
        log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1

        if coef > 0:
            log_a0 = _log_add(log_a0, log_s0)
            log_a1 = _log_add(log_a1, log_s1)
        else:
            log_a0 = _log_sub(log_a0, log_s0)
            log_a1 = _log_sub(log_a1, log_s1)

        i += 1
        if max(log_s0, log_s1) < -30:
            break

    return _log_add(log_a0, log_a1)


def _compute_log_a(q, sigma, alpha):
    """Compute log(A_alpha) for any positive finite alpha."""
    if float(alpha).is_integer():
        return _compute_log_a_int(q, sigma, int(alpha))
    else:
        return _compute_log_a_frac(q, sigma, alpha)


def _log_erfc(x):
    """Compute log(erfc(x)) with high accuracy for large x."""
    try:
        return math.log(2) + special.log_ndtr(-x * 2**.5)
    except (AttributeError, NameError):
        # A scipy without log_ndtr raises AttributeError on the lookup above
        # (NameError kept for compatibility with the upstream code).
        # Approximate as follows:
        r = special.erfc(x)
        if r == 0.0:
            # Using the Laurent series at infinity for the tail of the erfc function:
            #     erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5)
            # To verify in Mathematica:
            #     Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}]
            return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 +
                    .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8)
        else:
            return math.log(r)


def _compute_delta(orders, rdp, eps):
    """Compute delta given a list of RDP values and target epsilon.

    Args:
      orders: An array (or a scalar) of orders.
      rdp: A list (or a scalar) of RDP guarantees.
      eps: The target epsilon.

    Returns:
      Pair of (delta, optimal_order).

    Raises:
      ValueError: If input is malformed.
    """
    orders_vec = np.atleast_1d(orders)
    rdp_vec = np.atleast_1d(rdp)

    if len(orders_vec) != len(rdp_vec):
        raise ValueError("Input lists must have the same length.")

    deltas = np.exp((rdp_vec - eps) * (orders_vec - 1))
    idx_opt = np.argmin(deltas)
    return min(deltas[idx_opt], 1.), orders_vec[idx_opt]


def _compute_eps(orders, rdp, delta):
    """Compute epsilon given a list of RDP values and target delta.

    Args:
      orders: An array (or a scalar) of orders.
      rdp: A list (or a scalar) of RDP guarantees.
      delta: The target delta.

    Returns:
      Pair of (eps, optimal_order).

    Raises:
      ValueError: If input is malformed.
    """
    orders_vec = np.atleast_1d(orders)
    rdp_vec = np.atleast_1d(rdp)

    if len(orders_vec) != len(rdp_vec):
        raise ValueError("Input lists must have the same length.")

    eps = rdp_vec - math.log(delta) / (orders_vec - 1)

    idx_opt = np.nanargmin(eps)  # Ignore NaNs
    return eps[idx_opt], orders_vec[idx_opt]


def _compute_rdp(q, sigma, alpha):
    """Compute RDP of the Sampled Gaussian mechanism at order alpha.

    Args:
      q: The sampling rate.
      sigma: The std of the additive Gaussian noise.
      alpha: The order at which RDP is computed.

    Returns:
      RDP at alpha, can be np.inf.
    """
    if q == 0:
        return 0

    if q == 1.:
        return alpha / (2 * sigma**2)

    if np.isinf(alpha):
        return np.inf

    return _compute_log_a(q, sigma, alpha) / (alpha - 1)


def compute_rdp(q, noise_multiplier, steps, orders):
    """Compute RDP of the Sampled Gaussian Mechanism.

    Args:
      q: The sampling rate.
      noise_multiplier: The ratio of the standard deviation of the Gaussian noise
        to the l2-sensitivity of the function to which it is added.
      steps: The number of steps.
      orders: An array (or a scalar) of RDP orders.

    Returns:
      The RDPs at all orders, can be np.inf.
    """
    if np.isscalar(orders):
        rdp = _compute_rdp(q, noise_multiplier, orders)
    else:
        rdp = np.array([_compute_rdp(q, noise_multiplier, order)
                        for order in orders])

    return rdp * steps


def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None):
    """Compute delta (or eps) for given eps (or delta) from RDP values.

    Args:
      orders: An array (or a scalar) of RDP orders.
      rdp: An array of RDP values. Must be of the same length as the orders list.
      target_eps: If not None, the epsilon for which we compute the corresponding
                  delta.
      target_delta: If not None, the delta for which we compute the corresponding
                  epsilon. Exactly one of target_eps and target_delta must be None.

    Returns:
      eps, delta, opt_order.

    Raises:
      ValueError: If target_eps and target_delta are messed up.
    """
    if target_eps is None and target_delta is None:
        raise ValueError(
            "Exactly one out of eps and delta must be None. (Both are).")

    if target_eps is not None and target_delta is not None:
        raise ValueError(
            "Exactly one out of eps and delta must be None. (None is).")

    if target_eps is not None:
        delta, opt_order = _compute_delta(orders, rdp, target_eps)
        return target_eps, delta, opt_order
    else:
        eps, opt_order = _compute_eps(orders, rdp, target_delta)
        return eps, target_delta, opt_order


def compute_rdp_from_ledger(ledger, orders):
    """Compute RDP of Sampled Gaussian Mechanism from ledger.

    Args:
      ledger: A formatted privacy ledger.
      orders: An array (or a scalar) of RDP orders.

    Returns:
      RDP at all orders, can be np.inf.
    """
    total_rdp = 0
    for sample in ledger:
        # Compute equivalent z from l2_clip_bounds and noise stddevs in sample.
        # See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula.
        effective_z = sum([
            (q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5
        total_rdp += compute_rdp(
            sample.selection_probability, effective_z, 1, orders)
    return total_rdp


# ==== /pyvacy/analysis/subsampled.py ====
import math
from scipy.special import beta


def binomial(x, y):
    """Return the binomial coefficient C(x, y) via 1 / ((x + 1) * B(x - y + 1, y + 1))."""
    return 1 / ((x + 1) * beta(x - y + 1, y + 1))


def epsilon(alpha, q, noise_multiplier):
    """Upper-bound the RDP at integer order ``alpha`` of the subsampled Gaussian mechanism.

    Implements the general subsampling bound of Wang, Balle & Kasiviswanathan,
    "Subsampled Renyi Differential Privacy and Analytical Moments Accountant"
    (2019), with eps(a) = a / (2 sigma^2) for the unit-sensitivity Gaussian:

        1/(alpha-1) * log(1 + q^2 C(alpha,2) min{4(e^{eps(2)}-1), e^{eps(2)} min{2, (e^{eps(inf)}-1)^2}}
                          + sum_{j=3}^{alpha} q^j C(alpha,j) e^{(j-1) eps(j)} min{2, (e^{eps(inf)}-1)^j})

    Args:
        alpha (int): RDP order, must be >= 2.
        q (float): Sampling rate.
        noise_multiplier (float): Ratio of noise std to l2-sensitivity.

    Returns:
        float: the RDP upper bound at ``alpha``.

    Raises:
        ValueError: If ``alpha < 2``.
    """
    if alpha < 2:
        raise ValueError("alpha must be an integer >= 2")

    def _eps(_alpha):
        # Gaussian mechanism RDP; it grows without bound, so eps(inf) = inf.
        if _alpha == math.inf:
            return math.inf
        return _alpha / (2 * (noise_multiplier ** 2))

    def _cap(power):
        # min{2, (e^{eps(inf)} - 1)^power}; math.exp(inf) is inf, so no overflow.
        return min(2, (math.exp(_eps(math.inf)) - 1) ** power)

    s = 0.0
    for j in range(3, alpha + 1):
        s += (q ** j) * binomial(alpha, j) * math.exp((j - 1) * _eps(j)) * _cap(j)

    eps2 = _eps(2)
    second_order = (q ** 2) * binomial(alpha, 2) * min(4 * (math.exp(eps2) - 1), math.exp(eps2) * _cap(2))
    return (1 / (alpha - 1)) * math.log(1 + second_order + s)


# ==== /pyvacy/optim/__init__.py ====
# Package re-export; the dp_optimizer definitions follow inline in this dump.
# from .dp_optimizer import *

# ==== /pyvacy/optim/dp_optimizer.py ====
import torch
from torch.optim import Optimizer
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.distributions.normal import Normal
from torch.optim import SGD, Adam, Adagrad, RMSprop


def make_optimizer_class(cls):
    """Return a differentially private subclass of the torch optimizer class ``cls``.

    Protocol per minibatch:

    * ``zero_grad()`` clears the clipped-gradient accumulators.
    * For every microbatch: ``zero_microbatch_grad()``, backward, then
      ``microbatch_step()``. This clips the microbatch gradient (global L2 norm
      over all parameters) to ``l2_norm_clip`` and adds it to the accumulators.
    * ``step()`` adds Gaussian noise with std ``l2_norm_clip * noise_multiplier``,
      rescales by ``microbatch_size / minibatch_size`` and calls ``cls.step``.
    """

    class DPOptimizerClass(cls):
        def __init__(self, l2_norm_clip, noise_multiplier, minibatch_size, microbatch_size, *args, **kwargs):
            super(DPOptimizerClass, self).__init__(*args, **kwargs)

            self.l2_norm_clip = l2_norm_clip
            self.noise_multiplier = noise_multiplier
            self.microbatch_size = microbatch_size
            self.minibatch_size = minibatch_size

            # One clipped-gradient accumulator per trainable parameter.
            for group in self.param_groups:
                group['accum_grads'] = [torch.zeros_like(param.data) if param.requires_grad else None
                                        for param in group['params']]

        def zero_microbatch_grad(self):
            """Clear ``param.grad`` via the base optimizer before a microbatch backward pass."""
            super(DPOptimizerClass, self).zero_grad()

        def microbatch_step(self):
            """Clip the current microbatch gradient and add it to the accumulators.

            Parameters with no gradient (``param.grad is None``, e.g. unused in
            the forward pass) contribute zero instead of raising.
            """
            total_norm = 0.
            for group in self.param_groups:
                for param in group['params']:
                    if param.requires_grad and param.grad is not None:
                        total_norm += param.grad.data.norm(2).item() ** 2.
            total_norm = total_norm ** .5
            clip_coef = min(self.l2_norm_clip / (total_norm + 1e-6), 1.)

            for group in self.param_groups:
                for param, accum_grad in zip(group['params'], group['accum_grads']):
                    if param.requires_grad and param.grad is not None:
                        accum_grad.add_(param.grad.data.mul(clip_coef))

        def zero_grad(self, set_to_none=False):
            """Reset the clipped-gradient accumulators in place.

            ``set_to_none`` is accepted only for signature compatibility with
            ``torch.optim.Optimizer.zero_grad``. The accumulators are reused
            across minibatches, so they are always zero-filled.
            """
            for group in self.param_groups:
                for accum_grad in group['accum_grads']:
                    if accum_grad is not None:
                        accum_grad.zero_()

        def step(self, *args, **kwargs):
            """Noise and rescale the accumulated gradients, then apply them via ``cls.step``."""
            for group in self.param_groups:
                for param, accum_grad in zip(group['params'], group['accum_grads']):
                    if param.requires_grad:
                        noisy_grad = accum_grad.clone()
                        noisy_grad.add_(self.l2_norm_clip * self.noise_multiplier * torch.randn_like(noisy_grad))
                        noisy_grad.mul_(self.microbatch_size / self.minibatch_size)
                        # A parameter may never have received a gradient; it still gets noise.
                        if param.grad is None:
                            param.grad = noisy_grad
                        else:
                            param.grad.data = noisy_grad
            super(DPOptimizerClass, self).step(*args, **kwargs)

    return DPOptimizerClass


DPAdam = make_optimizer_class(Adam)
DPAdagrad = make_optimizer_class(Adagrad)
DPSGD = make_optimizer_class(SGD)
DPRMSprop = make_optimizer_class(RMSprop)


# ==== /pyvacy/sampling/__init__.py ====
# Package re-export; the batch_samplers definitions follow inline in this dump.
# from .batch_samplers import *

# ==== /pyvacy/sampling/batch_samplers.py ====
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np


class IIDBatchSampler:
    """Poisson batch sampler.

    Each index is included independently with probability
    ``minibatch_size / len(dataset)``. Empty draws are skipped, so fewer than
    ``iterations`` batches may be yielded.
    """

    def __init__(self, dataset, minibatch_size, iterations):
        self.length = len(dataset)
        self.minibatch_size = minibatch_size
        self.iterations = iterations

    def __iter__(self):
        for _ in range(self.iterations):
            indices = np.where(torch.rand(self.length) < (self.minibatch_size / self.length))[0]
            if indices.size > 0:
                yield indices

    def __len__(self):
        return self.iterations


class EquallySizedAndIndependentBatchSampler:
    """Yield ``iterations`` batches of exactly ``minibatch_size`` indices drawn with replacement."""

    def __init__(self, dataset, minibatch_size, iterations):
        self.length = len(dataset)
        self.minibatch_size = minibatch_size
        self.iterations = iterations

    def __iter__(self):
        for _ in range(self.iterations):
            yield np.random.choice(self.length, self.minibatch_size)

    def __len__(self):
        return self.iterations


def get_data_loaders(minibatch_size, microbatch_size, iterations, drop_last=True):
    """Return ``(minibatch_loader, microbatch_loader)`` factories for DP-SGD.

    ``minibatch_loader(dataset)`` samples minibatches with ``IIDBatchSampler``.
    ``microbatch_loader(minibatch)`` splits a minibatch dataset sequentially
    into ``microbatch_size`` chunks.
    """
    def minibatch_loader(dataset):
        return DataLoader(
            dataset,
            batch_sampler=IIDBatchSampler(dataset, minibatch_size, iterations)
        )

    def microbatch_loader(minibatch):
        return DataLoader(
            minibatch,
            batch_size=microbatch_size,
            # Using less data than allowed will yield no worse of a privacy guarantee,
            # and sometimes processing uneven batches can cause issues during training, e.g. when
            # using BatchNorm (although BatchNorm in particular should be analyzed separately
            # for privacy, since it's maintaining internal information about forward passes
            # over time without noise addition.)
            # Use separate IIDBatchSampler class if a more granular training process is needed.
            drop_last=drop_last,
        )

    return minibatch_loader, microbatch_loader


# ==== /requirements.txt ====
# torch==1.8.1
# torchvision==0.9.1
# tensorboard==2.5.0
# nni==2.2
# pillow==8.2.0
# pandas
# matplotlib

# ==== /transformer.py ====
import torch
import torch.nn.functional as F
from torch import nn


class Transformer(nn.Module):
    """Transformer-encoder classifier over a flat feature vector.

    The batch-normalized input of width ``input_dim`` is cut into three
    consecutive slices of ``input_dim // 3`` features. These are stacked into
    ``input_dim // 3`` positions of 3 channels each, projected to
    ``hidden_dim``, encoded, flattened and classified by two linear layers.

    NOTE(review): the third slice takes everything after ``2 * (input_dim // 3)``,
    so ``torch.cat`` in ``forward`` only succeeds when ``input_dim`` is a
    multiple of 3.
    """

    def __init__(self, input_dim=10, output_dim=2, hidden_dim=64, num_layers=3, nhead=4, dropout=0.2):
        super().__init__()
        self.bn = nn.BatchNorm1d(input_dim)
        self.length = input_dim // 3
        self.hidden_dim = hidden_dim
        self.fc1 = nn.Linear(3, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead, dim_feedforward=128,
                                                   dropout=dropout, activation='gelu')
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc2 = nn.Linear(hidden_dim * self.length, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.bn(x)
        x1, x2, x3 = x[:, :self.length].unsqueeze(2), x[:, self.length:self.length*2].unsqueeze(2), \
            x[:, self.length*2:].unsqueeze(2)
        x = torch.cat([x1, x2, x3], dim=-1)
        # (batch, seq, hidden) -> (seq, batch, hidden): the encoder is built without batch_first.
        x = F.gelu(self.fc1(x).permute(1, 0, 2))
        x = self.transformer_encoder(x)
        x = x.permute(1, 0, 2).contiguous()
        x = x.view(-1, self.hidden_dim * self.length)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# ==== /utils.py ====
import numpy as np
import torch
import random
from torch.nn import init


def seed_worker(worker_id):
    """Seed numpy and ``random`` from torch's initial seed (DataLoader ``worker_init_fn``).

    ``worker_id`` is unused. The seed is reduced modulo 2**32 because
    ``np.random.seed`` only accepts 32-bit seeds.
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def weight_init(m):
    """Initialize a module in place; intended for ``model.apply(weight_init)``.

    Conv*/Linear: Xavier-normal weights, zero bias.
    *Norm* layers: weight 1, bias 0.
    Missing weight/bias attributes (or ``None``) are skipped.
    """
    classname = m.__class__.__name__
    if classname.startswith('Conv') or classname == 'Linear':
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias, 0.0)
        if getattr(m, 'weight', None) is not None:
            init.xavier_normal_(m.weight)
    elif 'Norm' in classname:
        if getattr(m, 'weight', None) is not None:
            m.weight.data.fill_(1)
        if getattr(m, 'bias', None) is not None:
            m.bias.data.zero_()


def get_model(model_type, num_cls, input_dim):
    """Instantiate a model by name.

    Project modules are imported lazily, so only the selected architecture
    needs to be importable. For "mia_fc" and "transformer", ``num_cls`` is
    used as the input width and the output is 2 classes.

    Raises:
        ValueError: If ``model_type`` is not recognized.
    """
    if model_type == "resnet18":
        from cifar10_models import resnet18
        model = resnet18(pretrained=False, num_classes=num_cls)
    elif model_type == "vgg16":
        from cifar10_models import vgg16
        model = vgg16(pretrained=False, num_classes=num_cls)
    elif model_type == "densenet121":
        from cifar10_models import densenet121
        model = densenet121(pretrained=False, num_classes=num_cls)
    elif model_type == "columnfc":
        from models import ColumnFC
        model = ColumnFC(input_dim=input_dim, output_dim=num_cls)
    elif model_type == "mia_fc":
        from models import MIAFC
        model = MIAFC(input_dim=num_cls, output_dim=2)
    elif model_type == "transformer":
        from transformer import Transformer
        model = Transformer(input_dim=num_cls, output_dim=2)
    else:
        raise ValueError(f"Unknown model type: {model_type!r}")
    return model


def get_optimizer(optimizer_name, parameters, lr, weight_decay=0):
    """Build an optimizer by name.

    "sgd" uses momentum 0.9; "adam" uses betas (0.9, 0.999). An empty name
    returns ``None`` (no optimizer).

    Raises:
        ValueError: If ``optimizer_name`` is not recognized.
    """
    if optimizer_name == "sgd":
        optimizer = torch.optim.SGD(parameters, lr=lr, momentum=0.9, weight_decay=weight_decay)
    elif optimizer_name == "adam":
        optimizer = torch.optim.Adam(parameters, lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
    elif optimizer_name == "":
        optimizer = None
    else:
        raise ValueError(f"Unknown optimizer: {optimizer_name!r}")
    return optimizer