├── LICENSE ├── README.md ├── aggregation.py ├── architectures ├── purchase.py └── svhn.py ├── datasets └── purchase │ ├── data1.npz │ ├── data2.npz │ ├── dataloader.py │ ├── datasetfile │ └── prepare_data.py ├── distribution.py ├── example-scripts └── purchase-sharding │ ├── README.txt │ ├── data.sh │ ├── init.sh │ ├── predict.sh │ └── train.sh ├── sharded.py ├── sisa.py └── time.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 CleverHans Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Unlearning with SISA 2 | ### Lucas Bourtoule, Varun Chandrasekaran, Christopher Choquette-Choo, Hengrui Jia, Adelin Travers, Baiwu Zhang, David Lie, Nicolas Papernot 3 | This repository contains the core code used in the SISA experiments of our [Machine Unlearning](https://arxiv.org/abs/1912.03817) paper, along with some example scripts. 4 | 5 | You can start running experiments by having a look at the README in the purchase example folder at ``example-scripts/purchase-sharding``. 6 | 7 | ``sisa.py`` is the script that trains a given shard. It should be run as many times as the number of shards.
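As a rough illustration of that workflow (this driver is not part of the repository; the shell scripts under ``example-scripts/purchase-sharding`` are the reference way to run it), training and evaluating every shard and then aggregating the per-shard votes could look like the sketch below. The flags mirror ``train.sh``, ``predict.sh`` and ``aggregation.py``, and a container named ``5`` is assumed to have been created with ``init.sh 5``:

```python
# Hypothetical driver loop: one sisa.py run per shard, then a single aggregation step.
import subprocess

shards = 5                 # the example scripts also use the shard count as the container name
container = str(shards)

for shard in range(shards):
    # Train the constituent model of this shard.
    subprocess.run(
        ["python", "sisa.py", "--model", "purchase", "--train", "--slices", "1",
         "--dataset", "datasets/purchase/datasetfile", "--label", "0",
         "--epochs", "20", "--batch_size", "16", "--learning_rate", "0.001",
         "--optimizer", "sgd", "--chkpt_interval", "1",
         "--container", container, "--shard", str(shard)],
        check=True,
    )
    # Compute and store this shard's predictions on the test set.
    subprocess.run(
        ["python", "sisa.py", "--model", "purchase", "--test",
         "--dataset", "datasets/purchase/datasetfile", "--label", "0",
         "--batch_size", "16", "--container", container, "--shard", str(shard)],
        check=True,
    )

# Aggregate the per-shard outputs with a uniform vote and print the ensemble accuracy.
subprocess.run(
    ["python", "aggregation.py", "--strategy", "uniform", "--container", container,
     "--shards", container, "--dataset", "datasets/purchase/datasetfile", "--label", "0"],
    check=True,
)
```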
8 | 9 | ## Citing this work 10 | 11 | If you use this repository for academic research, you are highly encouraged 12 | (though not required) to cite our paper: 13 | 14 | ``` 15 | @inproceedings{bourtoule2021machine, 16 | title={Machine Unlearning}, 17 | author={Lucas Bourtoule and Varun Chandrasekaran and Christopher Choquette-Choo and Hengrui Jia and Adelin Travers and Baiwu Zhang and David Lie and Nicolas Papernot}, 18 | booktitle={Proceedings of the 42nd IEEE Symposium on Security and Privacy}, 19 | year={2021} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /aggregation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | import importlib 5 | 6 | import argparse 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument( 10 | "--strategy", default="uniform", help="Voting strategy, default uniform" 11 | ) 12 | parser.add_argument("--container", help="Name of the container") 13 | parser.add_argument("--shards", type=int, default=1, help="Number of shards, default 1") 14 | parser.add_argument( 15 | "--dataset", 16 | default="datasets/purchase/datasetfile", 17 | help="Location of the datasetfile, default datasets/purchase/datasetfile", 18 | ) 19 | parser.add_argument( 20 | "--baseline", type=int, help="Use only the specified shard (lone shard baseline)" 21 | ) 22 | parser.add_argument("--label", default="latest", help="Label, default latest") 23 | args = parser.parse_args() 24 | 25 | # Load dataset metadata. 26 | with open(args.dataset) as f: 27 | datasetfile = json.loads(f.read()) 28 | dataloader = importlib.import_module( 29 | ".".join(args.dataset.split("/")[:-1] + [datasetfile["dataloader"]]) 30 | ) 31 | 32 | # Output files used for the vote. 33 | if args.baseline != None: 34 | filenames = ["shard-{}:{}.npy".format(args.baseline, args.label)] 35 | else: 36 | filenames = ["shard-{}:{}.npy".format(i, args.label) for i in range(args.shards)] 37 | 38 | # Concatenate output files. 39 | outputs = [] 40 | for filename in filenames: 41 | outputs.append( 42 | np.load( 43 | os.path.join("containers/{}/outputs".format(args.container), filename), 44 | allow_pickle=True, 45 | ) 46 | ) 47 | outputs = np.array(outputs) 48 | 49 | # Compute weight vector based on given strategy. 50 | if args.strategy == "uniform": 51 | weights = ( 52 | 1 / outputs.shape[0] * np.ones((outputs.shape[0],)) 53 | ) # pylint: disable=unsubscriptable-object 54 | elif args.strategy.startswith("models:"): 55 | models = np.array(args.strategy.split(":")[1].split(",")).astype(int) 56 | weights = np.zeros((outputs.shape[0],)) # pylint: disable=unsubscriptable-object 57 | weights[models] = 1 / models.shape[0] # pylint: disable=unsubscriptable-object 58 | elif args.strategy == "proportional": 59 | split = np.load( 60 | "containers/{}/splitfile.npy".format(args.container), allow_pickle=True 61 | ) 62 | weights = np.array([shard.shape[0] for shard in split]) 63 | 64 | # Tensor contraction of outputs and weights (on the shard dimension). 65 | votes = np.argmax( 66 | np.tensordot(weights.reshape(1, weights.shape[0]), outputs, axes=1), axis=2 67 | ).reshape( 68 | (outputs.shape[1],) 69 | ) # pylint: disable=unsubscriptable-object 70 | 71 | # Load labels. 72 | _, labels = dataloader.load(np.arange(datasetfile["nb_test"]), category="test") 73 | 74 | # Compute and print accuracy. 
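# Each shard's output file holds one softmax or one-hot vector per test point; the weighted sum over shards
# followed by the argmax above implements the ensemble vote, so accuracy is the fraction of test points whose
# voted label matches the ground truth.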
75 | accuracy = ( 76 | np.where(votes == labels)[0].shape[0] / outputs.shape[1] 77 | ) # pylint: disable=unsubscriptable-object 78 | print(accuracy) 79 | -------------------------------------------------------------------------------- /architectures/purchase.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Module, Linear 2 | from torch.nn.functional import tanh 3 | 4 | class Model(Module): 5 | def __init__(self, input_shape, nb_classes, *args, **kwargs): 6 | super(Model, self).__init__() 7 | self.fc1 = Linear(input_shape[0], 128) 8 | self.fc2 = Linear(128, nb_classes) 9 | 10 | def forward(self, x): 11 | x = self.fc1(x) 12 | x = tanh(x) 13 | x = self.fc2(x) 14 | 15 | return x -------------------------------------------------------------------------------- /architectures/svhn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import sys 7 | import numpy as np 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True) 11 | 12 | def conv_init(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv') != -1: 15 | init.xavier_uniform(m.weight, gain=np.sqrt(2)) 16 | init.constant(m.bias, 0) 17 | elif classname.find('BatchNorm') != -1: 18 | init.constant(m.weight, 1) 19 | init.constant(m.bias, 0) 20 | 21 | class wide_basic(nn.Module): 22 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 23 | super(wide_basic, self).__init__() 24 | self.bn1 = nn.BatchNorm2d(in_planes) 25 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 26 | self.dropout = nn.Dropout(p=dropout_rate) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True) 29 | 30 | self.shortcut = nn.Sequential() 31 | if stride != 1 or in_planes != planes: 32 | self.shortcut = nn.Sequential( 33 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 34 | ) 35 | 36 | def forward(self, x): 37 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 38 | out = self.conv2(F.relu(self.bn2(out))) 39 | out += self.shortcut(x) 40 | 41 | return out 42 | 43 | class Model(nn.Module): 44 | def __init__(self, input_shape, nb_classes, *args, **kwargs): 45 | depth = 1 46 | widen_factor = 1 47 | dropout_rate = kwargs['dropout_rate'] if 'dropout_rate' in kwargs.keys() else 0.4 48 | num_classes = nb_classes 49 | 50 | super(Model, self).__init__() 51 | self.in_planes = 16 52 | 53 | # assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4' 54 | n = int((depth-4)/6) 55 | k = widen_factor 56 | 57 | print('| Wide-Resnet %dx%d' %(depth, k)) 58 | print(n) 59 | nStages = [16, 16*k, 32*k, 64*k] 60 | 61 | self.conv1 = conv3x3(3,nStages[0]) 62 | self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1) 63 | self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2) 64 | self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2) 65 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 66 | self.linear = nn.Linear(nStages[3], num_classes) 67 | 68 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 69 | strides = [stride] + [1]*(num_blocks-1) 70 | layers = [] 71 | 72 | for stride in strides: 73 | 
layers.append(block(self.in_planes, planes, dropout_rate, stride)) 74 | self.in_planes = planes 75 | 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | out = self.layer1(out) 81 | out = self.layer2(out) 82 | out = self.layer3(out) 83 | out = F.relu(self.bn1(out)) 84 | out = F.avg_pool2d(out, 8) 85 | out = out.view(out.size(0), -1) 86 | out = self.linear(out) 87 | 88 | return out -------------------------------------------------------------------------------- /datasets/purchase/data1.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cleverhans-lab/machine-unlearning/50a3ada6c777f6882a1764d74a4978bfd61ec6f9/datasets/purchase/data1.npz -------------------------------------------------------------------------------- /datasets/purchase/data2.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cleverhans-lab/machine-unlearning/50a3ada6c777f6882a1764d74a4978bfd61ec6f9/datasets/purchase/data2.npz -------------------------------------------------------------------------------- /datasets/purchase/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | pwd = os.path.dirname(os.path.realpath(__file__)) 5 | 6 | train_data = np.load(os.path.join(pwd, 'purchase2_train.npy'), allow_pickle=True) 7 | test_data = np.load(os.path.join(pwd, 'purchase2_test.npy'), allow_pickle=True) 8 | 9 | train_data = train_data.reshape((1,))[0] 10 | test_data = test_data.reshape((1,))[0] 11 | 12 | X_train = train_data['X'].astype(np.float32) 13 | X_test = test_data['X'].astype(np.float32) 14 | y_train = train_data['y'].astype(np.int64) 15 | y_test = test_data['y'].astype(np.int64) 16 | 17 | def load(indices, category='train'): 18 | if category == 'train': 19 | return X_train[indices], y_train[indices] 20 | elif category == 'test': 21 | return X_test[indices], y_test[indices] -------------------------------------------------------------------------------- /datasets/purchase/datasetfile: -------------------------------------------------------------------------------- 1 | { 2 | "nb_train": 280368, 3 | "nb_test": 31151, 4 | "input_shape": [600], 5 | "nb_classes": 2, 6 | "dataloader": "dataloader" 7 | } -------------------------------------------------------------------------------- /datasets/purchase/prepare_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.cluster import KMeans 4 | from sklearn.model_selection import train_test_split 5 | from scipy.sparse import load_npz 6 | 7 | 8 | data = np.concatenate([load_npz('data1.npz').toarray(), load_npz('data2.npz').toarray()]).astype(int) 9 | 10 | num_class = 2 11 | 12 | if not os.path.exists(f'{num_class}_kmeans.npy'): 13 | kmeans = KMeans(n_clusters=num_class, random_state=0).fit(data) 14 | label = kmeans.labels_ 15 | np.save(f'{num_class}_kmeans.npy', label) 16 | else: 17 | label = np.load(f'{num_class}_kmeans.npy') 18 | 19 | if not os.path.exists(f'purchase{num_class}_train.npy'): 20 | X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=0.2) 21 | np.save(f'purchase{num_class}_train.npy', {'X': X_train, 'y': y_train}) 22 | np.save(f'purchase{num_class}_test.npy', {'X': X_test, 'y': y_test}) 23 | -------------------------------------------------------------------------------- /distribution.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import os 4 | 5 | import argparse 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument( 9 | "--shards", 10 | default=None, 11 | type=int, 12 | help="Split the dataset in the given number of shards in an optimized manner (PLS-GAP partitionning) according to the given distribution, create the corresponding splitfile", 13 | ) 14 | parser.add_argument( 15 | "--requests", 16 | default=None, 17 | type=int, 18 | help="Generate the given number of unlearning requests according to the given distribution and apply them directly to the splitfile", 19 | ) 20 | parser.add_argument( 21 | "--distribution", 22 | default="uniform", 23 | help="Assumed distribution when used with --shards, sampling distribution when used with --requests. Use 'reset' to reset requestfile, default uniform", 24 | ) 25 | parser.add_argument("--container", default="default", help="Name of the container") 26 | parser.add_argument( 27 | "--dataset", 28 | default="datasets/purchase/datasetfile", 29 | help="Location of the datasetfile, default datasets/purchase/datasetfile", 30 | ) 31 | parser.add_argument("--label", default="latest", help="Label, default latest") 32 | args = parser.parse_args() 33 | 34 | # Load dataset metadata. 35 | with open(args.dataset) as f: 36 | datasetfile = json.loads(f.read()) 37 | 38 | if args.shards != None: 39 | # If distribution is uniform, split without optimizing. 40 | if args.distribution == "uniform": 41 | partition = np.split( 42 | np.arange(0, datasetfile["nb_train"]), 43 | [ 44 | t * (datasetfile["nb_train"] // args.shards) 45 | for t in range(1, args.shards) 46 | ], 47 | ) 48 | np.save("containers/{}/splitfile.npy".format(args.container), partition) 49 | requests = np.array([[] for _ in range(args.shards)]) 50 | np.save( 51 | "containers/{}/requestfile:{}.npy".format(args.container, args.label), 52 | requests, 53 | ) 54 | 55 | # Else run PLS-GAP algorithm to find a low cost split. 56 | else: 57 | 58 | def mass(index): 59 | if args.distribution.split(":")[0] == "exponential": 60 | lbd = ( 61 | float(args.distribution.split(":")[1]) 62 | if len(args.distribution.split(":")) > 1 63 | else -np.log(0.05) / datasetfile["nb_train"] 64 | ) 65 | return np.exp(-lbd * index) - np.exp(-lbd * (index + 1)) 66 | if args.distribution.split(":")[0] == "pareto": 67 | a = ( 68 | float(args.distribution.split(":")[1]) 69 | if len(args.distribution.split(":")) > 1 70 | else 1.16 71 | ) 72 | return a / ((index + 1) ** (a + 1)) 73 | 74 | if args.shards != None: 75 | # Initialize queue and partition. 76 | weights = mass(np.arange(0, datasetfile["nb_train"])) 77 | indices = np.argsort(weights) 78 | queue = np.array([weights[indices], np.ones(weights.shape)]).transpose() 79 | partition = [np.array([index]) for index in indices] 80 | 81 | # Put all points in the top queue. 82 | bottom_queue = queue.shape[0] # pylint: disable=unsubscriptable-object 83 | lim = ( 84 | int(float(args.algo.split(":")[1]) * datasetfile["nb_train"]) 85 | if len(args.algo.split(":")) > 1 86 | else int(0.01 * datasetfile["nb_train"]) 87 | ) 88 | 89 | for _ in range(datasetfile["nb_train"] - args.shards): 90 | # Fetch top 2 clusters and merge them. 
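# `queue` holds one row per current cluster, stored as [cost, number_of_points]; rows [0, bottom_queue) form the
# "top" queue (clusters still below the size limit `lim`, ordered by point count then cost) and the remaining rows
# form the "bottom" queue (oversized clusters, ordered by cost only). `partition` keeps the matching index sets in
# the same order. Each iteration pops the two clusters at the head of the queue, merges them, and re-inserts the
# result into whichever queue its new size dictates.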
91 | w1 = queue[0] 92 | w2 = queue[1] 93 | 94 | l1 = partition[0] 95 | l2 = partition[1] 96 | 97 | partition = partition[2:] 98 | queue = queue[2:] 99 | bottom_queue -= 2 100 | 101 | merged_weight = w1 + w2 102 | 103 | # If merged cluster is smaller in number of points than the limit, insert it in top queue. 104 | if merged_weight[1] < lim: 105 | # Top queue is ordered first by number of points (weight[1]) and second by cost (weight[0]). 106 | offset_array = np.where(queue[:bottom_queue, 1] >= merged_weight[1]) 107 | limit_array = np.where(queue[:bottom_queue, 1] > merged_weight[1]) 108 | offset = ( 109 | offset_array[0][0] 110 | if offset_array[0].shape[0] > 0 111 | else bottom_queue 112 | ) 113 | limit = ( 114 | limit_array[0][0] 115 | if limit_array[0].shape[0] > 0 116 | else bottom_queue 117 | ) 118 | position_array = np.where( 119 | queue[offset:limit][:, 0] >= merged_weight[0] 120 | ) 121 | position = ( 122 | position_array[0][0] 123 | if position_array[0].shape[0] > 0 124 | else bottom_queue 125 | ) 126 | bottom_queue += 1 127 | 128 | # Otherwise insert it in the bottom queue. 129 | else: 130 | # Bottom queue is ordered by cost only. 131 | position_array = np.where( 132 | queue[bottom_queue:][:, 0] >= merged_weight[0] 133 | ) 134 | position = ( 135 | position_array[0][0] 136 | if position_array[0].shape[0] > 0 137 | else queue.shape[0] 138 | ) 139 | 140 | # Actual insertion. 141 | queue = np.insert(queue, position, merged_weight, axis=0) 142 | partition = ( 143 | partition[:position] 144 | + [np.concatenate((l1, l2))] 145 | + partition[position:] 146 | ) 147 | 148 | # Generate splitfile and empty request file. 149 | np.save("containers/{}/splitfile.npy".format(args.container), partition) 150 | requests = np.array([[] for _ in range(partition.shape[0])]) 151 | np.save( 152 | "containers/{}/requestfile:{}.npy".format(args.container, args.label), 153 | requests, 154 | ) 155 | 156 | if args.requests != None: 157 | if args.distribution == "reset": 158 | requests = np.array([[] for _ in range(partition.shape[0])]) 159 | np.save( 160 | "containers/{}/requestfile:{}.npy".format(args.container, args.label), 161 | requests, 162 | ) 163 | else: 164 | # Load splitfile. 165 | partition = np.load( 166 | "containers/{}/splitfile.npy".format(args.container), allow_pickle=True 167 | ) 168 | 169 | # Randomly select points to be removed with given distribution at the dataset scale. 170 | if args.distribution.split(":")[0] == "exponential": 171 | lbd = ( 172 | float(args.distribution.split(":")[1]) 173 | if len(args.distribution.split(":")) > 1 174 | else -np.log(0.05) / datasetfile["nb_train"] 175 | ) 176 | all_requests = np.random.exponential(1 / lbd, (args.requests,)) 177 | if args.distribution.split(":")[0] == "pareto": 178 | a = ( 179 | float(args.distribution.split(":")[1]) 180 | if len(args.distribution.split(":")) > 1 181 | else 1.16 182 | ) 183 | all_requests = np.random.pareto(a, (args.requests,)) 184 | else: 185 | all_requests = np.random.randint(0, datasetfile["nb_train"], args.requests) 186 | 187 | requests = [] 188 | # Divide up the new requests among the shards. 189 | for shard in range(partition.shape[0]): 190 | requests.append(np.intersect1d(partition[shard], all_requests)) 191 | 192 | # Update requestfile. 
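# Each entry of `requests` is the array of training-set indices to unlearn for the corresponding shard;
# sharded.py later excludes these indices (via np.setdiff1d) when building batches and slice hashes, so saving
# this file is what effectively removes the requested points from subsequent (re)training.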
193 | np.save( 194 | "containers/{}/requestfile:{}.npy".format(args.container, args.label), 195 | np.array(requests), 196 | ) 197 | -------------------------------------------------------------------------------- /example-scripts/purchase-sharding/README.txt: -------------------------------------------------------------------------------- 1 | Before running experiments on the Purchase dataset for the first time, or to customize the number of labels used, please run `prepare_data.py` in `machine-unlearning/datasets/purchase`. 2 | 3 | The following scripts allow you to run a sharding experiment on the Purchase dataset. 4 | 5 | 1- Create a container with a specified number of shards: 6 | init.sh 5 7 | 8 | 2- Train the shards in the container: 9 | train.sh 5 10 | 11 | 3- Compute shard predictions: 12 | predict.sh 5 13 | 14 | 4- Retrieve experimental data as a CSV: 15 | data.sh 5 16 | -------------------------------------------------------------------------------- /example-scripts/purchase-sharding/data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eou pipefail 4 | IFS=$'\n\t' 5 | 6 | shards=$1 7 | 8 | if [[ ! -f general-report.csv ]]; then 9 | echo "nb_shards,nb_requests,accuracy,retraining_time" > general-report.csv 10 | fi 11 | 12 | for j in {0..15}; do 13 | r=$((${j}*${shards}/5)) 14 | acc=$(python aggregation.py --strategy uniform --container "${shards}" --shards "${shards}" --dataset datasets/purchase/datasetfile --label "${r}") 15 | cat containers/"${shards}"/times/shard-*:"${r}".time > "containers/${shards}/times/times.tmp" 16 | time=$(python time.py --container "${shards}" | awk -F ',' '{print $1}') 17 | echo "${shards},${r},${acc},${time}" >> general-report.csv 18 | done 19 | -------------------------------------------------------------------------------- /example-scripts/purchase-sharding/init.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eou pipefail 4 | IFS=$'\n\t' 5 | 6 | shards=$1 7 | 8 | if [[ !
-d "containers/${shards}" ]] ; then 9 | mkdir "containers/${shards}" 10 | mkdir "containers/${shards}/cache" 11 | mkdir "containers/${shards}/times" 12 | mkdir "containers/${shards}/outputs" 13 | echo 0 > "containers/${shards}/times/null.time" 14 | fi 15 | 16 | python distribution.py --shards "${shards}" --distribution uniform --container "${shards}" --dataset datasets/purchase/datasetfile --label 0 17 | 18 | for j in {1..15}; do 19 | r=$((${j}*${shards}/5)) 20 | python distribution.py --requests "${r}" --distribution uniform --container "${shards}" --dataset datasets/purchase/datasetfile --label "${r}" 21 | done 22 | -------------------------------------------------------------------------------- /example-scripts/purchase-sharding/predict.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eou pipefail 4 | IFS=$'\n\t' 5 | 6 | shards=$1 7 | 8 | for i in $(seq 0 "$((${shards}-1))"); do 9 | for j in {0..15}; do 10 | echo "shard: $((${i}+1))/${shards}, requests: $((${j}+1))/16" 11 | r=$((${j}*${shards}/5)) 12 | python sisa.py --model purchase --test --dataset datasets/purchase/datasetfile --label "${r}" --batch_size 16 --container "${shards}" --shard "${i}" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /example-scripts/purchase-sharding/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -eou pipefail 4 | IFS=$'\n\t' 5 | 6 | shards=$1 7 | 8 | for i in $(seq 0 "$((${shards}-1))"); do 9 | for j in {0..15}; do 10 | echo "shard: $((${i}+1))/${shards}, requests: $((${j}+1))/16" 11 | r=$((${j}*${shards}/5)) 12 | python sisa.py --model purchase --train --slices 1 --dataset datasets/purchase/datasetfile --label "${r}" --epochs 20 --batch_size 16 --learning_rate 0.001 --optimizer sgd --chkpt_interval 1 --container "${shards}" --shard "${i}" 13 | done 14 | done 15 | -------------------------------------------------------------------------------- /sharded.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from hashlib import sha256 3 | import importlib 4 | import json 5 | 6 | def sizeOfShard(container, shard): 7 | ''' 8 | Returns the size (in number of points) of the shard before any unlearning request. 9 | ''' 10 | shards = np.load('containers/{}/splitfile.npy'.format(container), allow_pickle=True) 11 | 12 | return shards[shard].shape[0] 13 | 14 | def realSizeOfShard(container, label, shard): 15 | ''' 16 | Returns the actual size of the shard (including unlearning requests). 17 | ''' 18 | shards = np.load('containers/{}/splitfile.npy'.format(container), allow_pickle=True) 19 | requests = np.load('containers/{}/requestfile:{}.npy'.format(container, label), allow_pickle=True) 20 | 21 | return shards[shard].shape[0] - requests[shard].shape[0] 22 | 23 | def getShardHash(container, label, shard, until=None): 24 | ''' 25 | Returns a hash of the indices of the points in the shard lower than until 26 | that are not in the requests (separated by :). 
27 | ''' 28 | shards = np.load('containers/{}/splitfile.npy'.format(container), allow_pickle=True) 29 | requests = np.load('containers/{}/requestfile:{}.npy'.format(container, label), allow_pickle=True) 30 | 31 | if until is None: 32 | until = shards[shard].shape[0] 33 | indices = np.setdiff1d(shards[shard][:until], requests[shard]) 34 | string_of_indices = ':'.join(indices.astype(str)) 35 | return sha256(string_of_indices.encode()).hexdigest() 36 | 37 | def fetchShardBatch(container, label, shard, batch_size, dataset, offset=0, until=None): 38 | ''' 39 | Generator returning batches of points in the shard that are not in the requests 40 | with specified batch_size from the specified dataset 41 | optionally located between offset and until (slicing). 42 | ''' 43 | shards = np.load('containers/{}/splitfile.npy'.format(container), allow_pickle=True) 44 | requests = np.load('containers/{}/requestfile:{}.npy'.format(container, label), allow_pickle=True) 45 | 46 | with open(dataset) as f: 47 | datasetfile = json.loads(f.read()) 48 | dataloader = importlib.import_module('.'.join(dataset.split('/')[:-1] + [datasetfile['dataloader']])) 49 | if until is None or until > shards[shard].shape[0]: 50 | until = shards[shard].shape[0] 51 | 52 | limit = offset 53 | while limit <= until - batch_size: 54 | limit += batch_size 55 | indices = np.setdiff1d(shards[shard][limit-batch_size:limit], requests[shard]) 56 | yield dataloader.load(indices) 57 | if limit < until: 58 | indices = np.setdiff1d(shards[shard][limit:until], requests[shard]) 59 | yield dataloader.load(indices) 60 | 61 | def fetchTestBatch(dataset, batch_size): 62 | ''' 63 | Generator returning batches of points from the specified test dataset 64 | with specified batch_size. 65 | ''' 66 | with open(dataset) as f: 67 | datasetfile = json.loads(f.read()) 68 | dataloader = importlib.import_module('.'.join(dataset.split('/')[:-1] + [datasetfile['dataloader']])) 69 | 70 | limit = 0 71 | while limit <= datasetfile['nb_test'] - batch_size: 72 | limit += batch_size 73 | yield dataloader.load(np.arange(limit - batch_size, limit), category='test') 74 | if limit < datasetfile['nb_test']: 75 | yield dataloader.load(np.arange(limit, datasetfile['nb_test']), category='test') 76 | -------------------------------------------------------------------------------- /sisa.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import CrossEntropyLoss 4 | from torch.optim import Adam, SGD 5 | from torch.nn.functional import one_hot, softmax 6 | from sharded import sizeOfShard, getShardHash, fetchShardBatch, fetchTestBatch 7 | import os 8 | from glob import glob 9 | from time import time 10 | import json 11 | 12 | import argparse 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument( 16 | "--model", default="purchase", help="Architecture to use, default purchase" 17 | ) 18 | 19 | parser.add_argument( 20 | "--train", action="store_true", help="Perform SISA training on the shard" 21 | ) 22 | parser.add_argument("--test", action="store_true", help="Compute shard predictions") 23 | 24 | parser.add_argument( 25 | "--epochs", 26 | default=20, 27 | type=int, 28 | help="Train for the specified number of epochs, default 20", 29 | ) 30 | parser.add_argument( 31 | "--batch_size", 32 | default=16, 33 | type=int, 34 | help="Size of the batches, relevant for both train and test, default 16", 35 | ) 36 | parser.add_argument( 37 | "--dropout_rate", 38 | default=0.4, 39 | type=float, 40 |
help="Dropout rate, if relevant, default 0.4", 41 | ) 42 | parser.add_argument( 43 | "--learning_rate", default=0.001, type=float, help="Learning rate, default 0.001" 44 | ) 45 | 46 | parser.add_argument("--optimizer", default="sgd", help="Optimizer, default sgd") 47 | 48 | parser.add_argument( 49 | "--output_type", 50 | default="argmax", 51 | help="Type of outputs to be used in aggregation, can be either argmax or softmax, default argmax", 52 | ) 53 | 54 | parser.add_argument("--container", help="Name of the container") 55 | parser.add_argument("--shard", type=int, help="Index of the shard to train/test") 56 | parser.add_argument( 57 | "--slices", default=1, type=int, help="Number of slices to use, default 1" 58 | ) 59 | parser.add_argument( 60 | "--dataset", 61 | default="datasets/purchase/datasetfile", 62 | help="Location of the datasetfile, default datasets/purchase/datasetfile", 63 | ) 64 | 65 | parser.add_argument( 66 | "--chkpt_interval", 67 | default=1, 68 | type=int, 69 | help="Interval (in epochs) between two chkpts, -1 to disable chackpointing, default 1", 70 | ) 71 | parser.add_argument( 72 | "--label", 73 | default="latest", 74 | help="Label to be used on simlinks and outputs, default latest", 75 | ) 76 | args = parser.parse_args() 77 | 78 | # Import the architecture. 79 | from importlib import import_module 80 | 81 | model_lib = import_module("architectures.{}".format(args.model)) 82 | 83 | # Retrive dataset metadata. 84 | with open(args.dataset) as f: 85 | datasetfile = json.loads(f.read()) 86 | input_shape = tuple(datasetfile["input_shape"]) 87 | nb_classes = datasetfile["nb_classes"] 88 | 89 | # Use GPU if available. 90 | device = torch.device( 91 | "cuda:0" if torch.cuda.is_available() else "cpu" 92 | ) # pylint: disable=no-member 93 | 94 | # Instantiate model and send to selected device. 95 | model = model_lib.Model(input_shape, nb_classes, dropout_rate=args.dropout_rate) 96 | model.to(device) 97 | 98 | # Instantiate loss and optimizer. 99 | loss_fn = CrossEntropyLoss() 100 | if args.optimizer == "adam": 101 | optimizer = Adam(model.parameters(), lr=args.learning_rate) 102 | elif args.optimizer == "sgd": 103 | optimizer = SGD(model.parameters(), lr=args.learning_rate) 104 | else: 105 | raise "Unsupported optimizer" 106 | 107 | if args.train: 108 | shard_size = sizeOfShard(args.container, args.shard) 109 | slice_size = shard_size // args.slices 110 | avg_epochs_per_slice = ( 111 | 2 * args.slices / (args.slices + 1) * args.epochs / args.slices 112 | ) 113 | loaded = False 114 | 115 | for sl in range(args.slices): 116 | # Get slice hash using sharded lib. 117 | slice_hash = getShardHash( 118 | args.container, args.label, args.shard, until=(sl + 1) * slice_size 119 | ) 120 | 121 | # If checkpoints exists, skip the slice. 122 | if not os.path.exists( 123 | "containers/{}/cache/{}.pt".format(args.container, slice_hash) 124 | ): 125 | # Initialize state. 126 | elapsed_time = 0 127 | start_epoch = 0 128 | slice_epochs = int((sl + 1) * avg_epochs_per_slice) - int( 129 | sl * avg_epochs_per_slice 130 | ) 131 | 132 | # If weights are already in memory (from previous slice), skip loading. 133 | if not loaded: 134 | # Look for a recovery checkpoint for the slice. 135 | recovery_list = glob( 136 | "containers/{}/cache/{}_*.pt".format(args.container, slice_hash) 137 | ) 138 | if len(recovery_list) > 0: 139 | print( 140 | "Recovery mode for shard {} on slice {}".format(args.shard, sl) 141 | ) 142 | 143 | # Load weights. 
144 | model.load_state_dict(torch.load(recovery_list[0])) 145 | start_epoch = int( 146 | recovery_list[0].split("/")[-1].split(".")[0].split("_")[1] 147 | ) 148 | 149 | # Load time 150 | with open( 151 | "containers/{}/times/{}_{}.time".format( 152 | args.container, slice_hash, start_epoch 153 | ), 154 | "r", 155 | ) as f: 156 | elapsed_time = float(f.read()) 157 | 158 | # If there is no recovery checkpoint and this slice is not the first, load previous slice. 159 | elif sl > 0: 160 | previous_slice_hash = getShardHash( 161 | args.container, args.label, args.shard, until=sl * slice_size 162 | ) 163 | 164 | # Load weights. 165 | model.load_state_dict( 166 | torch.load( 167 | "containers/{}/cache/{}.pt".format( 168 | args.container, previous_slice_hash 169 | ) 170 | ) 171 | ) 172 | 173 | # Mark model as loaded for next slices. 174 | loaded = True 175 | 176 | # If this is the first slice, no need to load anything. 177 | elif sl == 0: 178 | loaded = True 179 | 180 | # Actual training. 181 | train_time = 0.0 182 | 183 | for epoch in range(start_epoch, slice_epochs): 184 | epoch_start_time = time() 185 | 186 | for images, labels in fetchShardBatch( 187 | args.container, 188 | args.label, 189 | args.shard, 190 | args.batch_size, 191 | args.dataset, 192 | until=(sl + 1) * slice_size if sl < args.slices - 1 else None, 193 | ): 194 | 195 | # Convert data to torch format and send to selected device. 196 | gpu_images = torch.from_numpy(images).to( 197 | device 198 | ) # pylint: disable=no-member 199 | gpu_labels = torch.from_numpy(labels).to( 200 | device 201 | ) # pylint: disable=no-member 202 | 203 | forward_start_time = time() 204 | 205 | # Perform basic training step. 206 | logits = model(gpu_images) 207 | loss = loss_fn(logits, gpu_labels) 208 | 209 | optimizer.zero_grad() 210 | loss.backward() 211 | 212 | optimizer.step() 213 | 214 | train_time += time() - forward_start_time 215 | 216 | # Create a checkpoint every chkpt_interval. 217 | if ( 218 | args.chkpt_interval != -1 219 | and epoch % args.chkpt_interval == args.chkpt_interval - 1 220 | ): 221 | # Save weights 222 | torch.save( 223 | model.state_dict(), 224 | "containers/{}/cache/{}_{}.pt".format( 225 | args.container, slice_hash, epoch 226 | ), 227 | ) 228 | 229 | # Save time 230 | with open( 231 | "containers/{}/times/{}_{}.time".format( 232 | args.container, slice_hash, epoch 233 | ), 234 | "w", 235 | ) as f: 236 | f.write("{}\n".format(train_time + elapsed_time)) 237 | 238 | # Remove previous checkpoint. 239 | if os.path.exists( 240 | "containers/{}/cache/{}_{}.pt".format( 241 | args.container, slice_hash, epoch - args.chkpt_interval 242 | ) 243 | ): 244 | os.remove( 245 | "containers/{}/cache/{}_{}.pt".format( 246 | args.container, slice_hash, epoch - args.chkpt_interval 247 | ) 248 | ) 249 | if os.path.exists( 250 | "containers/{}/times/{}_{}.time".format( 251 | args.container, slice_hash, epoch - args.chkpt_interval 252 | ) 253 | ): 254 | os.remove( 255 | "containers/{}/times/{}_{}.time".format( 256 | args.container, slice_hash, epoch - args.chkpt_interval 257 | ) 258 | ) 259 | 260 | # When training is complete, save slice. 261 | torch.save( 262 | model.state_dict(), 263 | "containers/{}/cache/{}.pt".format(args.container, slice_hash), 264 | ) 265 | with open( 266 | "containers/{}/times/{}.time".format(args.container, slice_hash), "w" 267 | ) as f: 268 | f.write("{}\n".format(train_time + elapsed_time)) 269 | 270 | # Remove previous checkpoint. 
271 | if os.path.exists( 272 | "containers/{}/cache/{}_{}.pt".format( 273 | args.container, slice_hash, args.epochs - args.chkpt_interval 274 | ) 275 | ): 276 | os.remove( 277 | "containers/{}/cache/{}_{}.pt".format( 278 | args.container, slice_hash, args.epochs - args.chkpt_interval 279 | ) 280 | ) 281 | if os.path.exists( 282 | "containers/{}/times/{}_{}.time".format( 283 | args.container, slice_hash, args.epochs - args.chkpt_interval 284 | ) 285 | ): 286 | os.remove( 287 | "containers/{}/times/{}_{}.time".format( 288 | args.container, slice_hash, args.epochs - args.chkpt_interval 289 | ) 290 | ) 291 | 292 | # If this is the last slice, create a symlink attached to it. 293 | if sl == args.slices - 1: 294 | os.symlink( 295 | "{}.pt".format(slice_hash), 296 | "containers/{}/cache/shard-{}:{}.pt".format( 297 | args.container, args.shard, args.label 298 | ), 299 | ) 300 | os.symlink( 301 | "{}.time".format(slice_hash), 302 | "containers/{}/times/shard-{}:{}.time".format( 303 | args.container, args.shard, args.label 304 | ), 305 | ) 306 | 307 | elif sl == args.slices - 1: 308 | os.symlink( 309 | "{}.pt".format(slice_hash), 310 | "containers/{}/cache/shard-{}:{}.pt".format( 311 | args.container, args.shard, args.label 312 | ), 313 | ) 314 | if not os.path.exists( 315 | "containers/{}/times/shard-{}:{}.time".format( 316 | args.container, args.shard, args.label 317 | ) 318 | ): 319 | os.symlink( 320 | "null.time", 321 | "containers/{}/times/shard-{}:{}.time".format( 322 | args.container, args.shard, args.label 323 | ), 324 | ) 325 | 326 | 327 | if args.test: 328 | # Load model weights from shard checkpoint (last slice). 329 | model.load_state_dict( 330 | torch.load( 331 | "containers/{}/cache/shard-{}:{}.pt".format( 332 | args.container, args.shard, args.label 333 | ) 334 | ) 335 | ) 336 | 337 | # Compute predictions batch per batch. 338 | outputs = np.empty((0, nb_classes)) 339 | for images, _ in fetchTestBatch(args.dataset, args.batch_size): 340 | # Convert data to torch format and send to selected device. 341 | gpu_images = torch.from_numpy(images).to(device) # pylint: disable=no-member 342 | 343 | if args.output_type == "softmax": 344 | # Actual batch prediction. 345 | logits = model(gpu_images) 346 | predictions = softmax(logits, dim=1).to("cpu") # Send back to cpu. 347 | 348 | # Convert back to numpy and concatenate with previous batches. 349 | outputs = np.concatenate((outputs, predictions.numpy())) 350 | 351 | else: 352 | # Actual batch prediction. 353 | logits = model(gpu_images) 354 | predictions = torch.argmax(logits, dim=1) # pylint: disable=no-member 355 | 356 | # Convert to one hot, send back to cpu, convert back to numpy and concatenate with previous batches. 357 | out = one_hot(predictions, nb_classes).to("cpu") 358 | outputs = np.concatenate((outputs, out.numpy())) 359 | 360 | # Save outputs in numpy format. 361 | outputs = np.array(outputs) 362 | np.save( 363 | "containers/{}/outputs/shard-{}:{}.npy".format( 364 | args.container, args.shard, args.label 365 | ), 366 | outputs, 367 | ) 368 | -------------------------------------------------------------------------------- /time.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import argparse 4 | 5 | # Compute stats based on the execution time (cumulated feed-forward + backprop.) 
of the shards 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--container', help="Name of the container") 9 | args = parser.parse_args() 10 | 11 | t = pd.read_csv('containers/{}/times/times.tmp'.format(args.container), names=['time']) 12 | print('{},{}'.format(t['time'].sum(),t['time'].mean())) --------------------------------------------------------------------------------
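As a final, hypothetical illustration (not part of the repository), the `general-report.csv` produced by `example-scripts/purchase-sharding/data.sh` can be summarized with pandas; the column names below come from the header that `data.sh` writes:

```python
# Hypothetical post-processing of the CSV written by example-scripts/purchase-sharding/data.sh.
import pandas as pd

report = pd.read_csv("general-report.csv")  # columns: nb_shards, nb_requests, accuracy, retraining_time

# Mean ensemble accuracy and total cumulative retraining time for each number of unlearning requests.
summary = report.groupby("nb_requests").agg({"accuracy": "mean", "retraining_time": "sum"})
print(summary)
```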