├── README.md
├── model.py
├── proof
│   └── .gitkeep
├── train.py
├── utils.py
└── verify.py

/README.md:
--------------------------------------------------------------------------------
# Proof-of-Learning

This repository is an implementation of the paper [Proof-of-Learning: Definitions and Practice](https://arxiv.org/abs/2103.05633), published at the 42nd IEEE Symposium on Security and Privacy. In this paper, we introduce the concept of proof-of-learning in ML. Inspired by research on both proof-of-work and verified computing, we observe how a seminal training algorithm, gradient descent, accumulates secret information due to its stochasticity. This produces a natural construction for a proof-of-learning, which demonstrates that a party has expended the compute required to obtain a set of model parameters correctly. For more details, please read the paper.

We test our code on two datasets: CIFAR-10 and CIFAR-100.

### Dependency
Our code is implemented and tested on PyTorch. The following packages are used:
```
numpy
pytorch==1.6.0
torchvision==0.7.0
scipy==1.6.0
```

### Train
To train a model and create a proof-of-learning:
```
python train.py --save-freq [checkpointing interval] --dataset [any dataset in torchvision] --model [models defined in model.py or any torchvision model]
```
`save-freq` is the checkpointing interval, denoted by k in the paper. The remaining arguments are defined at the end of the script.

Note that the proposed algorithm does not interact with the training process, so it can be applied to any model trained with gradient descent.
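The relationship between `save-freq`, the saved checkpoints, and the segments the verifier replays can be sketched as follows. This is a minimal illustration, assuming the behaviour of `utils.create_sequences`, the checkpointing loop in `train.py`, and the replay loop in `verify_all` in `verify.py`; the helpers `checkpoint_steps` and `verification_segments` and the `seed` argument are hypothetical and not part of this repository.

```python
import numpy as np


def create_sequences(batch_size, dataset_size, epochs, seed=None):
    # Mirrors utils.create_sequences; `seed` is a hypothetical addition for reproducibility.
    rng = np.random.default_rng(seed)
    # One random permutation of all dataset indices per epoch, concatenated.
    sequence = np.concatenate([rng.choice(dataset_size, size=dataset_size, replace=False)
                               for _ in range(epochs)])
    # Drop the tail that does not fill a whole batch; each row is one training step.
    num_batch = len(sequence) // batch_size
    return np.reshape(sequence[:num_batch * batch_size], [num_batch, batch_size])


def checkpoint_steps(num_steps, save_freq):
    # Hypothetical helper: train.py writes model_step_{i} whenever i % save_freq == 0,
    # plus a final model_step_{num_steps} after the last step.
    return list(range(0, num_steps, save_freq)) + [num_steps]


def verification_segments(num_steps, save_freq):
    # Hypothetical helper: each (start, end) pair is one segment that verify.py replays
    # from model_step_{start} and compares against model_step_{end}.
    steps = checkpoint_steps(num_steps, save_freq)
    return list(zip(steps[:-1], steps[1:]))


if __name__ == "__main__":
    # CIFAR-10-sized example: 50,000 training images, batch size 128, one epoch.
    sequence = create_sequences(batch_size=128, dataset_size=50000, epochs=1, seed=0)
    print(sequence.shape)                            # (390, 128)
    print(checkpoint_steps(sequence.shape[0], 100))  # [0, 100, 200, 300, 390]
    print(verification_segments(sequence.shape[0], 100))
```

For a CIFAR-10-sized run with `--save-freq 100`, this yields 390 steps, five checkpoints, and four replayed segments; the last segment covers only 90 steps because 390 is not a multiple of 100.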

### Verify
To verify a given proof-of-learning:
```
python verify.py --model-dir [path/to/the/proof] --dist [distance metric] --q [query budget] --delta [slack parameter]
```
Setting `q` to 0 or smaller verifies the whole proof; otherwise, only the top-q iterations of each epoch are verified. More information about `q` and `delta` can be found in the paper. For `dist`, you can use one or more of `1`, `2`, `inf`, and `cos` (separate multiple values with spaces). The first three are the corresponding l_p norms, while `cos` is the cosine distance. When more than one metric is given, the union of the top-q iterations under each metric is verified.

Please make sure `lr`, `batch-size`, `epochs`, `dataset`, `model`, and `save-freq` are consistent with those used in `train.py`.

### Questions or suggestions
If you have any questions or suggestions, feel free to raise an issue or send me an email at nickhengrui.jia@mail.utoronto.ca

### Citing this work
If you use this repository for academic research, you are highly encouraged (though not required) to cite our paper:
```
@inproceedings{jia2021proofoflearning,
  title={Proof-of-Learning: Definitions and Practice},
  author={Hengrui Jia and Mohammad Yaghini and Christopher A. Choquette-Choo and Natalie Dullerud and Anvith Thudi and Varun Chandrasekaran and Nicolas Papernot},
  booktitle={Proceedings of the 42nd IEEE Symposium on Security and Privacy},
  year={2021}
}
```

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class Simple_Conv(nn.Module):
    def __init__(self):
        super(Simple_Conv, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.fc1 = nn.Linear(64 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# https://github.com/akamaster/pytorch_resnet_cifar10/blob/master/resnet.py
def _weights_init(m):
    # Kaiming-normal initialization for every linear and convolutional layer
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)


class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
            if option == 'A':
                # The ResNet paper uses option A for CIFAR-10.
                self.shortcut = LambdaLayer(lambda x:
                                            F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            elif option == 'B':
                self.shortcut = nn.Sequential(
                    nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(self.expansion * planes)
                )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = nn.Linear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20():
    return ResNet(BasicBlock, [3, 3, 3])


def resnet32():
    return ResNet(BasicBlock, [5, 5, 5])


def resnet44():
    return ResNet(BasicBlock, [7, 7, 7])


def resnet56():
    return ResNet(BasicBlock, [9, 9, 9], 100)


def resnet110():
    return ResNet(BasicBlock, [18, 18, 18], 100)


def resnet1202():
    return ResNet(BasicBlock, [200, 200, 200])


def test(net):
    import numpy as np
    total_params = 0

    for x in filter(lambda p: p.requires_grad, net.parameters()):
        total_params += np.prod(x.data.numpy().shape)
    print("Total number of params", total_params)
    print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size()) > 1, net.parameters()))))


# https://github.com/weiaicunzai/pytorch-cifar100/blob/75cd6e633c0ffecd3ab49ef44ee8df9ed1919854/models/resnet.py
class BasicBlock2(nn.Module):
    """Basic Block for resnet 18 and resnet 34
    """

    # BasicBlock and BottleNeck blocks have different output sizes;
    # the class attribute `expansion` is used to distinguish them
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # residual function
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BasicBlock2.expansion, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels * BasicBlock2.expansion)
        )

        # shortcut
        self.shortcut = nn.Sequential()

        # if the shortcut output dimension differs from that of the residual function,
        # use a 1x1 convolution to match the dimensions
        if stride != 1 or in_channels != BasicBlock2.expansion * out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BasicBlock2.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * BasicBlock2.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class BottleNeck2(nn.Module):
    """Residual block for resnet over 50 layers
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels * BottleNeck2.expansion, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels * BottleNeck2.expansion),
        )

        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels * BottleNeck2.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * BottleNeck2.expansion, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels * BottleNeck2.expansion)
            )

    def forward(self, x):
        return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))


class ResNet2(nn.Module):

    def __init__(self, block, num_block, num_classes=100):
        super().__init__()

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # we use a different input size than the original paper,
        # so conv2_x's stride is 1
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Make a ResNet layer. Here a "layer" is not a single neural network
        layer (e.g. a conv layer); one layer may contain more than one residual block.

        Args:
            block: block type, basic block or bottleneck block
            out_channels: output depth channel number of this layer
            num_blocks: how many blocks per layer
            stride: the stride of the first block of this layer

        Return:
            a ResNet layer
        """

        # we have num_blocks blocks per layer; the stride of the first block
        # can be 1 or 2, and all other blocks always use stride 1
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        output = self.conv1(x)
        output = self.conv2_x(output)
        output = self.conv3_x(output)
        output = self.conv4_x(output)
        output = self.conv5_x(output)
        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)
        output = self.fc(output)

        return output


def resnet18():
    """ return a ResNet 18 object
    """
    return ResNet2(BasicBlock2, [2, 2, 2, 2])


def resnet34():
    """ return a ResNet 34 object
    """
    return ResNet2(BasicBlock2, [3, 4, 6, 3])


def resnet50():
    """ return a ResNet 50 object
    """
    return ResNet2(BottleNeck2, [3, 4, 6, 3])


def resnet101():
    """ return a ResNet 101 object
    """
    return ResNet2(BottleNeck2, [3, 4, 23, 3])


def resnet152():
    """ return a ResNet 152 object
    """
    return ResNet2(BottleNeck2, [3, 8, 36, 3])
--------------------------------------------------------------------------------
/proof/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cleverhans-lab/Proof-of-Learning/c286c7c0d2b45d80b4b6a7b0c2995690034927ff/proof/.gitkeep
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import hashlib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from collections import OrderedDict

import utils
import model as custom_model


def train(lr, batch_size, epochs, dataset, architecture, exp_id=None, sequence=None,
          model_dir=None, save_freq=None, num_gpu=torch.cuda.device_count(), verify=False, dec_lr=None,
          half=False, resume=False):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    if sequence is not None or model_dir is not None:
        resume = False

    try:
        trainset = utils.load_dataset(dataset, True)
    except:
        trainset = utils.load_dataset(dataset, True, download=True)

    if num_gpu > 1:
        net = nn.DataParallel(architecture())
        batch_size = batch_size * num_gpu
    else:
        net = architecture()
    num_batch = trainset.__len__() / batch_size
    net.to(device)
    if dataset == 'MNIST':
        optimizer = optim.SGD(net.parameters(), lr=lr)
        scheduler = None
    elif dataset == 'CIFAR10':
        if dec_lr is None:
            dec_lr = [100, 150]
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[round(i * num_batch) for i in dec_lr],
                                                   gamma=0.1)
    elif dataset == 'CIFAR100':
        if dec_lr is None:
            dec_lr = [60, 120, 160]
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=[round(i * num_batch) for i in dec_lr],
                                                   gamma=0.2)
    else:
        optimizer = optim.Adam(net.parameters(), lr=lr)
        scheduler = None

    criterion = torch.nn.CrossEntropyLoss().to(device)

    if model_dir is not None:
        # load a pre-trained model from model_dir if it is given
        state = torch.load(model_dir)
        new_state_dict = OrderedDict()
        try:
            # in case the current model is wrapped in nn.DataParallel,
            # whose parameter names carry a "module." prefix
            for k, v in state['net'].items():
                name = "module." + k
                new_state_dict[name] = v
            net.load_state_dict(new_state_dict)
        except:
            net.load_state_dict(state['net'])
        optimizer.load_state_dict(state['optimizer'])
        if scheduler is not None:
            try:
                scheduler.load_state_dict(state['scheduler'])
            except:
                scheduler = None

    if half:
        # round the parameters to half precision, then cast them back to float32
        net.half().float()

    if sequence is None:
        # if a training sequence is not given, create a new one
        train_size = trainset.__len__()
        sequence = utils.create_sequences(batch_size, train_size, epochs)

    ind = None
    if save_freq is not None and save_freq > 0:
        # save the sequence of data indices if save_freq is given
        save_dir = os.path.join("proof", f"{dataset}_{exp_id}")
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        else:
            if resume:
                try:
                    ind = -1
                    # find the most recent checkpoint
                    while os.path.exists(os.path.join(save_dir, f"model_step_{ind + 1}")):
                        ind = ind + 1
                    if ind >= 0:
                        model_dir = os.path.join(save_dir, f"model_step_{ind}")
                        state = torch.load(model_dir)
                        new_state_dict = OrderedDict()
                        try:
                            for k, v in state['net'].items():
                                name = "module." + k
                                new_state_dict[name] = v
                            net.load_state_dict(new_state_dict)
                        except:
                            net.load_state_dict(state['net'])
                        optimizer.load_state_dict(state['optimizer'])
                        if scheduler is not None:
                            try:
                                scheduler.load_state_dict(state['scheduler'])
                            except:
                                scheduler = None
                        sequence = np.load(os.path.join(save_dir, "indices.npy"))
                        sequence = sequence[ind:]
                        print('resume training')
                except:
                    print('resume failed')
                if ind == -1:
                    ind = None

        np.save(os.path.join(save_dir, "indices.npy"), sequence)

    num_step = sequence.shape[0]

    sequence = np.reshape(sequence, -1)
    subset = torch.utils.data.Subset(trainset, sequence)
    trainloader = torch.utils.data.DataLoader(subset, batch_size=batch_size, num_workers=0, pin_memory=True)
    net.train()

    if save_freq is not None and save_freq > 0:
        # record a SHA-256 digest of the training data alongside the proof
        m = hashlib.sha256()
        for d in subset.dataset.data:
            m.update(d.__str__().encode('utf-8'))
        with open(os.path.join(save_dir, "hash.txt"), "x") as f:
            f.write(m.hexdigest())

    for i, data in enumerate(trainloader, 0):
        if save_freq is not None and i % save_freq == 0 and save_freq > 0:
            # save a checkpoint every save_freq iterations
            state = {'net': net.state_dict(),
                     'optimizer': optimizer.state_dict()}
            if scheduler is not None:
                state['scheduler'] = scheduler.state_dict()
            if ind is None:
                torch.save(state, os.path.join(save_dir, f"model_step_{i}"))
            else:
                torch.save(state, os.path.join(save_dir, f"model_step_{i + ind}"))

        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if i > 0 and i % round(num_batch) == 0 and verify:
            print(f'Epoch {i // round(num_batch)}')
            validate(dataset, net, batch_size)
            net.train()

    if save_freq is not None and save_freq > 0:
        # save the final checkpoint (model_step_{num_step}) after the last training step
        state = {'net': net.state_dict(),
                 'optimizer': optimizer.state_dict()}
        if scheduler is not None:
            state['scheduler'] = scheduler.state_dict()
        torch.save(state, os.path.join(save_dir, f"model_step_{num_step}"))

    return net


def validate(dataset, model, batch_size=128):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    testset = utils.load_dataset(dataset, False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2, pin_memory=True)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy: {100 * correct / total} %')
    return correct / total


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--dataset', type=str, default="CIFAR10")
    parser.add_argument('--model', type=str, default="resnet20",
                        help="models defined in model.py or any torchvision model.\n"
                             "Recommendation for CIFAR-10: resnet20/32/44/56/110/1202\n"
                             "Recommendation for CIFAR-100: resnet18/34/50/101/152"
                        )
    parser.add_argument('--id', help='experiment id', type=str, default='test')
    parser.add_argument('--save-freq', type=int, default=100, help='frequency of saving checkpoints')
    parser.add_argument('--num-gpu', type=int, default=torch.cuda.device_count())
    parser.add_argument('--milestone', nargs='+', type=int, default=[100, 150])
    parser.add_argument('--verify', type=int, default=0)
    arg = parser.parse_args()

    print(f'trying to allocate {arg.num_gpu} gpus')
    try:
        architecture = eval(f"custom_model.{arg.model}")
    except:
        architecture = eval(f"torchvision.models.{arg.model}")
    trained_model = train(arg.lr, arg.batch_size, arg.epochs, arg.dataset, architecture, exp_id=arg.id,
                          save_freq=arg.save_freq, num_gpu=arg.num_gpu, dec_lr=arg.milestone,
                          verify=arg.verify, resume=True)
    validate(arg.dataset, trained_model)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from scipy import stats
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


def get_parameters(net, numpy=False):
    # flatten all weights of a torch model into a single 1-D tensor (or numpy array)
    parameter = torch.cat([i.data.reshape([-1]) for i in list(net.parameters())])
    if numpy:
        return parameter.cpu().numpy()
    else:
        return parameter


def set_parameters(net, parameters, device):
    # load weights from a list of numpy arrays to a torch model
    for i, (name, param) in enumerate(net.named_parameters()):
        param.data = torch.Tensor(parameters[i]).to(device)
    return net


def create_sequences(batch_size, dataset_size, epochs):
    # create the sequence of data indices used for training: one random permutation
    # of the dataset per epoch, reshaped to [num_batch, batch_size]
    sequence = np.concatenate([np.random.default_rng().choice(dataset_size, size=dataset_size, replace=False)
                               for i in range(epochs)])
    num_batch = int(len(sequence) // batch_size)
    return np.reshape(sequence[:num_batch * batch_size], [num_batch, batch_size])


def consistent_type(model, architecture=None,
                    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'), half=False):
    # takes a path to a saved checkpoint, model weights as a numpy array or tensor,
    # or a torch model, and returns the flattened model weights as a torch tensor
    if isinstance(model, str):
        assert architecture is not None
        state = torch.load(model)
        net = architecture()
        net.load_state_dict(state['net'])
        weights = get_parameters(net)
    elif isinstance(model, np.ndarray):
        weights = torch.tensor(model)
    elif not isinstance(model, torch.Tensor):
        weights = get_parameters(model)
    else:
        weights = model
    if half:
        weights = weights.half()
    return weights.to(device)


def parameter_distance(model1, model2, order=2, architecture=None, half=False):
    # compute the distance between 2 checkpoints under one or more metrics
    weights1 = consistent_type(model1, architecture, half=half)
    weights2 = consistent_type(model2, architecture, half=half)
    if not isinstance(order, list):
        orders = [order]
    else:
        orders = order
    res_list = []
    for o in orders:
        if o == 'inf':
            o = np.inf
        if o == 'cos' or o == 'cosine':
            res = (1 - torch.dot(weights1, weights2) /
                   (torch.norm(weights1) * torch.norm(weights2))).cpu().numpy()
        else:
            if o != np.inf:
                try:
                    o = int(o)
                except:
                    raise TypeError("input metric for distance is not understandable")
            res = torch.norm(weights1 - weights2, p=o).cpu().numpy()
        if isinstance(res, np.ndarray):
            res = float(res)
        res_list.append(res)
    return res_list


def load_dataset(dataset, train, download=False):
    try:
        dataset_class = eval(f"torchvision.datasets.{dataset}")
    except:
        raise NotImplementedError(f"Dataset {dataset} is not implemented by torchvision.")

    if dataset == "MNIST" or dataset == "FashionMNIST":
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5,), (0.5,))])
    elif dataset == "CIFAR100":
        if train:
            transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                                     (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
                                     (0.2673342858792401, 0.2564384629170883, 0.27615047132568404))])
    else:
        if train:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        else:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

    data = dataset_class(root='./data', train=train, download=download, transform=transform)
    return data


def ks_test(reference, rvs):
    # KS-style statistic: the largest gap between the reference CDF and the
    # empirical CDF of rvs, evaluated at the sorted samples
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        ecdf = torch.arange(rvs.shape[0]).float() / torch.tensor(rvs.shape)
        return torch.max(torch.abs(reference(torch.sort(rvs)[0]).to(device) - ecdf.to(device)))


def check_weights_initialization(param, method):
    if method == 'default':
        # Kaiming uniform (default for weights of nn.Conv and nn.Linear)
        fan = nn.init._calculate_correct_fan(param, 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', np.sqrt(5))
        std = gain / np.sqrt(fan)
        bound = np.sqrt(3.0) * std
        reference = torch.distributions.uniform.Uniform(-bound, bound).cdf
    elif method == 'resnet_cifar':
        # Kaiming normal
        fan = nn.init._calculate_correct_fan(param, 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', 0)
        std = gain / np.sqrt(fan)
        reference = torch.distributions.normal.Normal(0, std).cdf
    elif method == 'resnet':
        # Kaiming normal (default in conv layers of pytorch resnet)
        fan = nn.init._calculate_correct_fan(param, 'fan_out')
        gain = nn.init.calculate_gain('relu', 0)
        std = gain / np.sqrt(fan)
        reference = torch.distributions.normal.Normal(0, std).cdf
    elif method == 'default_bias':
        # default for bias of nn.Conv and nn.Linear
        weight, param = param
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
        bound = 1 / np.sqrt(fan_in)
        reference = torch.distributions.uniform.Uniform(-bound, bound).cdf
    else:
        raise NotImplementedError("Input initialization strategy is not implemented.")

    param = param.reshape(-1)
    ks_stats = ks_test(reference, param).cpu().item()
    return stats.kstwo.sf(ks_stats, param.shape[0])


def check_weights_initialization_scipy(param, method):
    if method == 'default':
        # Kaiming uniform (default for weights of nn.Conv and nn.Linear)
        fan = nn.init._calculate_correct_fan(param, 'fan_in')
        gain = nn.init.calculate_gain('leaky_relu', np.sqrt(5))
        std = gain / np.sqrt(fan)
        bound = np.sqrt(3.0) * std
        reference = stats.uniform(loc=-bound, scale=bound * 2).cdf
    elif method == 'resnet':
        # Kaiming normal (default in conv layers of pytorch resnet)
        fan = nn.init._calculate_correct_fan(param, 'fan_out')
        gain = nn.init.calculate_gain('relu', 0)
        std = gain / np.sqrt(fan)
        reference = stats.norm(loc=0, scale=std).cdf
    elif method == 'default_bias':
        # default for bias of nn.Conv and nn.Linear
        weight, param = param
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
        bound = 1 / np.sqrt(fan_in)
        reference = stats.uniform(loc=-bound, scale=bound * 2).cdf
    else:
        raise NotImplementedError("Input initialization strategy is not implemented.")

    param = param.detach().numpy().reshape(-1)
    return stats.kstest(param, reference)[1]
--------------------------------------------------------------------------------
/verify.py:
--------------------------------------------------------------------------------
import argparse
import os
import hashlib
import numpy as np
import torch
import torchvision
from functools import reduce

import utils
from train import train
import model as custom_model


def verify_all(dir, lr, batch_size, dataset, architecture, save_freq, order, threshold, half=0):
    if not os.path.exists(dir):
        raise FileNotFoundError("Model directory not found")
    sequence = np.load(os.path.join(dir, "indices.npy"))

    if not isinstance(order, list):
        order = [order]
        threshold = [threshold]
    else:
        assert len(order) == len(threshold)

    dist_list = [[] for i in range(len(order))]

    target_model = os.path.join(dir, "model_step_0")
    for i in range(0, sequence.shape[0], save_freq):
        previous_state = target_model
        if i + save_freq >= sequence.shape[0]:
            target_model = os.path.join(dir, f"model_step_{sequence.shape[0]}")
            reproduce = train(lr, batch_size, 0, dataset, architecture, model_dir=previous_state,
                              sequence=sequence[i:], half=half)
        else:
            target_model = os.path.join(dir, f"model_step_{i + save_freq}")
            reproduce = train(lr, batch_size, 0, dataset, architecture, model_dir=previous_state,
                              sequence=sequence[i:i + save_freq], half=half)
        res = utils.parameter_distance(target_model, reproduce, order=order,
                                       architecture=architecture,
                                       half=half)
        for j in range(len(order)):
            dist_list[j].append(res[j])

    dist_list = np.array(dist_list)
    for k in range(len(order)):
        print(f"Distance metric: {order[k]} || threshold: {threshold[k]}")
        print(f"Average distance: {np.average(dist_list[k])}")
        above_threshold = np.sum(dist_list[k] > threshold[k])
        if above_threshold == 0:
            print("None of the steps is above the threshold; the proof-of-learning is valid.")
        else:
            print(f"{above_threshold} / {dist_list[k].shape[0]} "
                  f"({100 * np.average(dist_list[k] > threshold[k])}%) "
                  f"of the steps are above the threshold; the proof-of-learning is invalid.")
    return dist_list


def verify_topq(dir, lr, batch_size, dataset, architecture, save_freq, order, threshold, epochs=1, q=10, half=0):
    if not os.path.exists(dir):
        raise FileNotFoundError("Model directory not found")
    sequence = np.load(os.path.join(dir, "indices.npy"))

    if not isinstance(order, list):
        order = [order]
        threshold = [threshold]
    else:
        assert len(order) == len(threshold)

    ckpt_per_epoch = sequence.shape[0] / epochs / save_freq
    results = []  # one array of top-q distances per epoch; kept separate from the per-step distances

    for epoch in range(epochs):
        print(f"Verifying epoch {epoch + 1}/{epochs}")
        start = int(np.round(ckpt_per_epoch * epoch))
        end = int(np.round(ckpt_per_epoch * (epoch + 1)))
        dist_list = [[] for i in range(len(order))]
        next_model = os.path.join(dir, f"model_step_{start * save_freq}")
        # First pass: distance between consecutive saved checkpoints, no retraining
        for i in range(start, end):
            current_model = next_model
            if (i + 1) * save_freq >= sequence.shape[0]:
                next_model = os.path.join(dir, f"model_step_{sequence.shape[0]}")
            else:
                next_model = os.path.join(dir, f"model_step_{(i + 1) * save_freq}")
            dist = utils.parameter_distance(current_model, next_model, order=order,
                                            architecture=architecture, half=half)
            for j in range(len(order)):
                dist_list[j].append(dist[j])

        dist_arr = np.array(dist_list)
        # argpartition raises if q exceeds the number of checkpoints in this epoch
        n_top = min(q, dist_arr.shape[1])
        topq_steps = np.argpartition(dist_arr, -n_top, axis=1)[:, -n_top:]
        # Union the top-q steps of all distance metrics to avoid redundant computation;
        # this also flattens the single-metric case to a 1-D array of step indices
        topq_steps = reduce(np.union1d, list(topq_steps))

        dist_list = [[] for i in range(len(order))]
        # Second pass: retrain only the selected steps and compare with the saved checkpoints
        for ind in topq_steps:
            step = int((start + ind) * save_freq)
            current_model = os.path.join(dir, f"model_step_{step}")
            if step + save_freq >= sequence.shape[0]:
                target_model = os.path.join(dir, f"model_step_{sequence.shape[0]}")
                reproduce = train(lr, batch_size, 0, dataset, architecture, model_dir=current_model,
                                  sequence=sequence[step:], half=half)
            else:
                target_model = os.path.join(dir, f"model_step_{step + save_freq}")
                reproduce = train(lr, batch_size, 0, dataset, architecture, model_dir=current_model,
                                  sequence=sequence[step:step+save_freq], half=half)
            dist = utils.parameter_distance(target_model, reproduce, order=order,
                                            architecture=architecture, half=half)
            for j in range(len(order)):
                dist_list[j].append(dist[j])

        dist_list = np.array(dist_list)
        for k in range(len(order)):
            print(f"Distance metric: {order[k]} || threshold: {threshold[k]} || Q={q}")
            print(f"Average top-q distance: {np.average(dist_list[k])}")
            above_threshold = np.sum(dist_list[k] > threshold[k])
            if above_threshold == 0:
                print("None of the steps is above the threshold; the proof-of-learning is valid.")
            else:
                print(f"{above_threshold} / {dist_list[k].shape[0]} "
                      f"({100 * np.average(dist_list[k] > threshold[k])}%)"
                      f" of the steps are above the threshold; the proof-of-learning is invalid.")
        results.append(dist_list)
    return results


def verify_initialization(dir, architecture, threshold=0.01, net=None, verbose=True):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if net is None:
        net = architecture()
        # map_location lets a checkpoint saved on a GPU be verified on a CPU-only machine
        state = torch.load(os.path.join(dir, "model_step_0"), map_location=device)
        net.load_state_dict(state['net'])
    net.to(device)
    model_name = architecture.__name__
    if model_name in ['resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet1202']:
        model_type = 'resnet_cifar'
    elif model_name == 'resnet50':
        model_type = 'resnet_cifar100'
    elif 'resnet' in model_name:
        model_type = 'resnet'
    else:
        model_type = 'default'
    p_list = []
    if model_type == 'resnet':
        for name, param in net.named_parameters():
            if 'weight' in name and 'conv' in name:
                p_list.append(utils.check_weights_initialization(param, 'resnet'))
            elif 'weight' in name and 'fc' in name:
                p_list.append(utils.check_weights_initialization(param, 'default'))
            elif 'bias' in name and ('fc' in name or 'linear' in name):
                weight = net.state_dict()[name.replace('bias', 'weight')]
                p_list.append(utils.check_weights_initialization([weight, param], 'default_bias'))
    elif model_type == 'resnet_cifar100':
        for name, param in net.named_parameters():
            if len(param.shape) == 4:
                p_list.append(utils.check_weights_initialization(param, 'default'))
            elif 'weight' in name and 'fc' in name:
                p_list.append(utils.check_weights_initialization(param, 'default'))
            elif 'bias' in name and ('fc' in name or 'linear' in name):
                weight = net.state_dict()[name.replace('bias', 'weight')]
                p_list.append(utils.check_weights_initialization([weight, param], 'default_bias'))
    elif model_type == 'resnet_cifar':
        for name, param in net.named_parameters():
            if 'fc' in name or 'conv' in name or 'linear' in name:
                if 'weight' in name:
                    p_list.append(utils.check_weights_initialization(param, 'resnet_cifar'))
                elif 'bias' in name:
                    weight = net.state_dict()[name.replace('bias', 'weight')]
                    p_list.append(utils.check_weights_initialization([weight,
                        param], 'default_bias'))
    else:
        for name, param in net.named_parameters():
            if 'fc' in name or 'conv' in name or 'linear' in name:
                if 'weight' in name:
                    p_list.append(utils.check_weights_initialization(param, 'default'))
                elif 'bias' in name:
                    weight = net.state_dict()[name.replace('bias', 'weight')]
                    p_list.append(utils.check_weights_initialization([weight, param], 'default_bias'))

    if verbose:
        if np.min(p_list) < threshold:
            print(f"The initialized weights do not follow the initialization strategy. "
                  f"The minimum p-value is {np.min(p_list)} < threshold ({threshold}). "
                  f"The proof-of-learning is not valid.")
        else:
            print("The proof-of-learning passed the initialization verification.")
    return p_list


def verify_hash(dir, dataset):
    if not os.path.exists(dir):
        raise FileNotFoundError("Model directory not found")
    sequence = np.load(os.path.join(dir, "indices.npy"))
    with open(os.path.join(dir, "hash.txt"), "r") as f:
        expected_hash = f.read()

    trainset = utils.load_dataset(dataset, True)
    subset = torch.utils.data.Subset(trainset, sequence)
    m = hashlib.sha256()
    for d in subset.dataset.data:
        m.update(d.__str__().encode('utf-8'))

    if expected_hash != m.hexdigest():
        print("Hash doesn't match. The proof is invalid.")
    else:
        print("Hash of the proof is valid.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--epochs', type=int, default=2)
    parser.add_argument('--dataset', type=str, default="CIFAR10")
    parser.add_argument('--model', type=str, default="resnet20",
                        help="models defined in model.py or any torchvision model.\n"
                             "Recommendation for CIFAR-10: resnet20/32/44/56/110/1202\n"
                             "Recommendation for CIFAR-100: resnet18/34/50/101/152"
                        )
    parser.add_argument('--model-dir', help='path/to/the/proof', type=str, default='proof/CIFAR10_test')
    parser.add_argument('--save-freq', type=int, default=100, help='frequency of saving checkpoints')
    parser.add_argument('--dist', type=str, nargs='+', default=['1', '2', 'inf', 'cos'],
                        help='distance metric(s): 1, 2, inf, or cos')
    parser.add_argument('--q', type=int, default=2, help="Set to >0 to enable top-q verification; "
                                                         "otherwise all steps will be verified.")
    # nargs='+' so that one threshold can be given per distance metric in --dist
    parser.add_argument('--delta', type=float, nargs='+', default=[1000, 10, 0.1, 0.01],
                        help='threshold(s) for verification, one per distance metric')

    arg = parser.parse_args()

    try:
        architecture = getattr(custom_model, arg.model)
    except AttributeError:
        architecture = getattr(torchvision.models, arg.model)

    verify_initialization(arg.model_dir, architecture)
    verify_hash(arg.model_dir, arg.dataset)

    if arg.q > 0:
        verify_topq(arg.model_dir, arg.lr, arg.batch_size, arg.dataset, architecture, arg.save_freq,
                    arg.dist, arg.delta, arg.epochs, q=arg.q)
    else:
        verify_all(arg.model_dir, arg.lr, arg.batch_size, arg.dataset, architecture, arg.save_freq,
                   arg.dist, arg.delta)

--------------------------------------------------------------------------------
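The step-selection rule in `verify_topq` (take the q largest checkpoint-to-checkpoint distances for each metric, then union them so each chosen step is retrained only once) can be sketched without NumPy or PyTorch. `select_topq_steps` is a hypothetical helper, not part of this repository, and the sketch assumes the first-pass distances arrive as one Python list per metric; it swaps `np.argpartition`/`np.union1d` for `heapq.nlargest` and set union.

```python
import heapq
from functools import reduce


def select_topq_steps(distances, q):
    """Hypothetical helper: sorted union of the indices of the q largest entries per row.

    distances: one list per distance metric; entry i is the distance between
    checkpoint i and checkpoint i + 1 within an epoch (the first, cheap pass).
    """
    per_metric = []
    for row in distances:
        n_top = min(q, len(row))  # never request more steps than the epoch has
        # indices of the n_top largest distances under this metric
        top = heapq.nlargest(n_top, range(len(row)), key=row.__getitem__)
        per_metric.append(set(top))
    # a step flagged by several metrics appears once, so it is retrained once
    return sorted(reduce(set.union, per_metric, set()))


if __name__ == "__main__":
    first_pass = [[0.1, 0.9, 0.3, 0.8],  # distances under one metric
                  [5.0, 1.0, 7.0, 2.0]]  # distances under another metric
    start, save_freq = 4, 100  # first checkpoint index of this epoch, and k
    chosen = select_topq_steps(first_pass, q=1)
    print(chosen)  # → [1, 2]
    # map checkpoint indices back to the training steps that would be retrained
    print([(start + i) * save_freq for i in chosen])  # → [500, 600]
```

Because the result is always a flat, sorted list, the single-metric case needs no special handling.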