├── models
│   ├── __init__.py
│   ├── alldnet.py
│   ├── vgg.py
│   ├── mobilenet.py
│   ├── googlenet.py
│   ├── resnext.py
│   ├── densenet.py
│   ├── resnet.py
│   └── nasnet.py
├── requirements.txt
├── LICENSE
├── README.md
├── utils.py
├── train_inside.py
└── train_cross.py
/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * 2 | from .resnet import * 3 | from .resnext import * 4 | from .densenet import * 5 | from .googlenet import * 6 | from .mobilenet import * 7 | from .nasnet import * -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2020.4.5.1 2 | cycler==0.10.0 3 | kiwisolver==1.2.0 4 | matplotlib==3.2.1 5 | numpy==1.18.3 6 | Pillow==7.1.1 7 | pyparsing==2.4.7 8 | python-dateutil==2.8.1 9 | six==1.14.0 10 | torch==1.4.0 11 | torchvision==0.5.0 12 | opencv-contrib-python==4.4.0.44 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Yangsibo Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /models/alldnet.py: -------------------------------------------------------------------------------- 1 | '''LeNet in PyTorch.''' 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | class AllDNet(nn.Module): 7 | def __init__(self): 8 | super(AllDNet, self).__init__() 9 | self.conv1 = nn.Conv2d(3, 6, 5) 10 | self.conv2 = nn.Conv2d(6, 16, 5) 11 | # self.conv2 = nn.Linear(6*14*14, 16*10*10) 12 | self.fc1 = nn.Linear(16*5*5, 120) 13 | self.fc2 = nn.Linear(120, 84) 14 | self.fc3 = nn.Linear(84, 10) 15 | 16 | def forward(self, x): 17 | activations = [] 18 | out = F.relu(self.conv1(x)) 19 | out = F.max_pool2d(out, 2) 20 | # out = out.view(out.size(0), -1) 21 | # activations.append(out) 22 | out = F.relu(self.conv2(out)) 23 | # out = out.view(out.size(0), 16, 10, -1) 24 | out = F.max_pool2d(out, 2) 25 | out = out.view(out.size(0), -1) 26 | activations.append(out) 27 | out = F.relu(self.fc1(out)) 28 | activations.append(out) 29 | out = F.relu(self.fc2(out)) 30 | activations.append(out) 31 | out = self.fc3(out) 32 | return out, activations 33 | 34 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG11/13/16/19 in Pytorch.''' 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | cfg = { 8 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 10 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 11 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 12 | } 13 | 14 | 15 | class VGG(nn.Module): 16 | def __init__(self, vgg_name): 17 | super(VGG, self).__init__() 18 | self.features = self._make_layers(cfg[vgg_name]) 19 | self.classifier = nn.Linear(512, 10) 20 | 21 | def forward(self, x): 22 | out = self.features(x) 23 | out = out.view(out.size(0), -1) 24 | out = self.classifier(out) 25 | return out 26 | 27 | def _make_layers(self, cfg): 28 | layers = [] 29 | in_channels = 3 30 | for x in cfg: 31 | if x == 'M': 32 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 33 | else: 34 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(x), 36 | nn.ReLU(inplace=True)] 37 | in_channels = x 38 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 39 | return nn.Sequential(*layers) 40 | 41 | # net = VGG('VGG11') 42 | # x = torch.randn(2,3,32,32) 43 | # print(net(Variable(x)).size()) 44 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | 3 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 4 | for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from torch.autograd import Variable 11 | 12 | 13 | class Block(nn.Module): 14 | '''Depthwise conv + Pointwise conv''' 15 | def __init__(self, in_planes, out_planes, stride=1): 16 | super(Block, self).__init__() 17 | self.conv1 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=in_planes, bias=False) 18 | self.bn1 = nn.BatchNorm2d(in_planes) 19 | self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | 22 | def forward(self, x): 23 | out = F.relu(self.bn1(self.conv1(x))) 24 | out = F.relu(self.bn2(self.conv2(out))) 25 | return out 26 | 27 | 28 | class MobileNet(nn.Module): 29 | # (128,2) means conv planes=128, conv stride=2, by default conv stride=1 30 | cfg = [64, (128,2), 128, (256,2), 256, (512,2), 512, 512, 512, 512, 512, (1024,2), 1024] 31 | 32 | def __init__(self, num_classes=10): 33 | super(MobileNet, self).__init__() 34 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 35 | self.bn1 = nn.BatchNorm2d(32) 36 | self.layers = self._make_layers(in_planes=32) 37 | self.linear = nn.Linear(1024, num_classes) 38 | 39 | def _make_layers(self, in_planes): 40 | layers = [] 41 | for x in self.cfg: 42 | out_planes = x if isinstance(x, int) else x[0] 43 | stride = 1 if isinstance(x, int) else x[1] 44 | layers.append(Block(in_planes, out_planes, stride)) 45 | in_planes = out_planes 46 | return nn.Sequential(*layers) 47 | 48 | def forward(self, x): 49 | out = F.relu(self.bn1(self.conv1(x))) 50 | out = self.layers(out) 51 | out = F.avg_pool2d(out, 2) 52 | out = out.view(out.size(0), -1) 53 | out = self.linear(out) 54 | return out 55 | 56 | 57 | def test(): 58 | net = MobileNet() 59 | x = torch.randn(1,3,32,32) 60 | y = net(Variable(x)) 61 | print(y.size()) 62 | 63 | # test() 64 | -------------------------------------------------------------------------------- /models/googlenet.py: -------------------------------------------------------------------------------- 1 | '''GoogLeNet with PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from torch.autograd import Variable 7 | 8 | 9 | class Inception(nn.Module): 10 | def __init__(self, in_planes, n1x1, n3x3red, n3x3, n5x5red, n5x5, pool_planes): 11 | super(Inception, self).__init__() 12 | # 1x1 conv branch 13 | self.b1 = nn.Sequential( 14 | nn.Conv2d(in_planes, n1x1, kernel_size=1), 15 | nn.BatchNorm2d(n1x1), 16 | nn.ReLU(True), 17 | ) 18 | 19 | # 1x1 conv -> 3x3 conv branch 20 | self.b2 = nn.Sequential( 21 | nn.Conv2d(in_planes, n3x3red, kernel_size=1), 22 | nn.BatchNorm2d(n3x3red), 23 | nn.ReLU(True), 24 | nn.Conv2d(n3x3red, n3x3, kernel_size=3, padding=1), 25 | nn.BatchNorm2d(n3x3), 26 | nn.ReLU(True), 27 | ) 28 | 29 | # 1x1 conv -> 5x5 conv branch 30 | self.b3 = nn.Sequential( 31 | nn.Conv2d(in_planes, n5x5red, kernel_size=1), 32 | nn.BatchNorm2d(n5x5red), 33 | nn.ReLU(True), 34 | nn.Conv2d(n5x5red, n5x5, kernel_size=3, padding=1), 35 | nn.BatchNorm2d(n5x5), 36 | nn.ReLU(True), 37 | nn.Conv2d(n5x5, n5x5, kernel_size=3, padding=1), 38 | nn.BatchNorm2d(n5x5), 39 | nn.ReLU(True), 40 | ) 41 | 42 | # 3x3 pool -> 1x1 conv branch 43 | self.b4 = nn.Sequential( 44 | nn.MaxPool2d(3, stride=1, padding=1), 45 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 46 | nn.BatchNorm2d(pool_planes), 47 | nn.ReLU(True), 48 | ) 49 | 50 | def forward(self, x): 51 | y1 = self.b1(x) 52 | y2 
= self.b2(x) 53 | y3 = self.b3(x) 54 | y4 = self.b4(x) 55 | return torch.cat([y1,y2,y3,y4], 1) 56 | 57 | 58 | class GoogLeNet(nn.Module): 59 | def __init__(self): 60 | super(GoogLeNet, self).__init__() 61 | self.pre_layers = nn.Sequential( 62 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 63 | nn.BatchNorm2d(192), 64 | nn.ReLU(True), 65 | ) 66 | 67 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 68 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 69 | 70 | self.maxpool = nn.MaxPool2d(3, stride=2, padding=1) 71 | 72 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 73 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 74 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 75 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 76 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 77 | 78 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 79 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 80 | 81 | self.avgpool = nn.AvgPool2d(8, stride=1) 82 | self.linear = nn.Linear(1024, 10) 83 | 84 | def forward(self, x): 85 | out = self.pre_layers(x) 86 | out = self.a3(out) 87 | out = self.b3(out) 88 | out = self.maxpool(out) 89 | out = self.a4(out) 90 | out = self.b4(out) 91 | out = self.c4(out) 92 | out = self.d4(out) 93 | out = self.e4(out) 94 | out = self.maxpool(out) 95 | out = self.a5(out) 96 | out = self.b5(out) 97 | out = self.avgpool(out) 98 | out = out.view(out.size(0), -1) 99 | out = self.linear(out) 100 | return out 101 | 102 | # net = GoogLeNet() 103 | # x = torch.randn(1,3,32,32) 104 | # y = net(Variable(x)) 105 | # print(y.size()) 106 | -------------------------------------------------------------------------------- /models/resnext.py: -------------------------------------------------------------------------------- 1 | '''ResNeXt in PyTorch. 2 | 3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details. 
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from torch.autograd import Variable 10 | 11 | 12 | class Block(nn.Module): 13 | '''Grouped convolution block.''' 14 | expansion = 2 15 | 16 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 17 | super(Block, self).__init__() 18 | group_width = cardinality * bottleneck_width 19 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(group_width) 21 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 22 | self.bn2 = nn.BatchNorm2d(group_width) 23 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 24 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 25 | 26 | self.shortcut = nn.Sequential() 27 | if stride != 1 or in_planes != self.expansion*group_width: 28 | self.shortcut = nn.Sequential( 29 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 30 | nn.BatchNorm2d(self.expansion*group_width) 31 | ) 32 | 33 | def forward(self, x): 34 | out = F.relu(self.bn1(self.conv1(x))) 35 | out = F.relu(self.bn2(self.conv2(out))) 36 | out = self.bn3(self.conv3(out)) 37 | out += self.shortcut(x) 38 | out = F.relu(out) 39 | return out 40 | 41 | 42 | class ResNeXt(nn.Module): 43 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10): 44 | super(ResNeXt, self).__init__() 45 | self.cardinality = cardinality 46 | self.bottleneck_width = bottleneck_width 47 | self.in_planes = 64 48 | 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(64) 51 | self.layer1 = self._make_layer(num_blocks[0], 1) 52 | self.layer2 = self._make_layer(num_blocks[1], 2) 53 | self.layer3 = self._make_layer(num_blocks[2], 2) 54 | # self.layer4 = self._make_layer(num_blocks[3], 2) 55 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 56 | 57 | def _make_layer(self, num_blocks, stride): 58 | strides = [stride] + [1]*(num_blocks-1) 59 | layers = [] 60 | for stride in strides: 61 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 62 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 63 | # Increase bottleneck_width by 2 after each stage. 
64 | self.bottleneck_width *= 2 65 | return nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | out = F.relu(self.bn1(self.conv1(x))) 69 | out = self.layer1(out) 70 | out = self.layer2(out) 71 | out = self.layer3(out) 72 | # out = self.layer4(out) 73 | out = F.avg_pool2d(out, 8) 74 | out = out.view(out.size(0), -1) 75 | out = self.linear(out) 76 | return out 77 | 78 | 79 | def ResNeXt29_2x64d(num_classes=10): 80 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes) 81 | 82 | def ResNeXt29_4x64d(num_classes=10): 83 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes) 84 | 85 | def ResNeXt29_8x64d(num_classes=10): 86 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes) 87 | 88 | def ResNeXt29_32x4d(num_classes=10): 89 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes) 90 | 91 | def test_resnext(): 92 | net = ResNeXt29_2x64d() 93 | x = torch.randn(1,3,32,32) 94 | y = net(Variable(x)) 95 | print(y.size()) 96 | 97 | # test_resnext() 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # InstaHide training on CIFAR-10 with PyTorch 2 | 3 | ## Overview 4 | InstaHide[1] is a practical instance-hiding method for image data encryption in privacy-sensitive distributed deep learning. 5 | 6 | InstaHide uses the Mixup[2] method with a one-time secret key consisting of a pixel-wise random sign-flipping mask and samples from the same training dataset (Inside-dataset InstaHide) or a large public dataset (Cross-dataset InstaHide). It can be easily plugged into an existing distributed learning pipeline; it is efficient and incurs only a minor reduction in accuracy. 7 | 8 | We also release a [challenge](https://github.com/Hazelsuko07/InstaHide_Challenge) to further investigate the security of InstaHide.
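For intuition, the core of an inside-dataset InstaHide encryption can be sketched as below. This is a simplified, per-image view of the batched `mixup_data` routine in `train_inside.py`; the function name and signature here are illustrative only, not part of this codebase:

```
import torch

def instahide_encrypt(x, others, lams):
    # x: one private image (C, H, W); others: the k-1 images mixed in;
    # lams: k non-negative coefficients that sum to 1.
    mixed = lams[0] * x
    for lam, img in zip(lams[1:], others):
        mixed = mixed + lam * img
    # One-time secret key: a fresh pixel-wise random +/-1 mask per encryption.
    sign = torch.randint(0, 2, x.shape).float() * 2 - 1
    return mixed * sign
```

The network is then trained directly on the encrypted images, with the labels mixed using the same coefficients (see `mixup_criterion`).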
9 | 10 | 11 | ## Citation 12 | If you use InstaHide or this codebase in your research, please cite our paper: 13 | ``` 14 | @inproceedings{hsla20, 15 | title = {InstaHide: Instance-hiding Schemes for Private Distributed Learning}, 16 | author = {Yangsibo Huang and Zhao Song and Kai Li and Sanjeev Arora}, 17 | booktitle = {International Conference on Machine Learning (ICML)}, 18 | year = {2020} 19 | } 20 | ``` 21 | 22 | ## How to run 23 | ### Install dependencies 24 | - Create an Anaconda environment with Python 3.6 25 | ``` 26 | conda create -n instahide python=3.6 27 | ``` 28 | - Run the following commands to install dependencies 29 | ``` 30 | conda activate instahide 31 | pip install -r requirements.txt 32 | ``` 33 | ### Important script arguments 34 | Training configurations: 35 | - `model`: network architecture (default: 'ResNet18') 36 | - `lr`: learning rate (default: 0.1) 37 | - `batch-size`: batch size (default: 128) 38 | - `decay`: weight decay (default: 1e-4) 39 | - `no-augment`: turn off data augmentation 40 | 41 | InstaHide configurations: 42 | - `klam`: the number of images mixed in an InstaHide encryption, `k` in the paper (default: 4) 43 | - `mode`: 'instahide' or 'mixup' (default: 'instahide') 44 | - `upper`: the upper bound on any single coefficient, `c1` in the paper (default: 0.65) 45 | - `dom`: the lower bound on the sum of the coefficients of the two private images, `c2` in the paper (default: 0.3, *only for Cross-dataset InstaHide*) 46 | 47 | ### Inside-dataset InstaHide: 48 | Inside-dataset InstaHide mixes each training image with random images within the same private training dataset. 49 | 50 | For inside-dataset InstaHide training, run the following script: 51 | ``` 52 | python train_inside.py --mode instahide --klam 4 --data cifar10 53 | ``` 54 | 55 | ### Cross-dataset InstaHide: 56 | Cross-dataset InstaHide, arguably more secure, mixes each training image with random images from a large public dataset. In the paper, we use the unlabelled [ImageNet](http://image-net.org/download)[3] as the public dataset. 57 | 58 | For cross-dataset InstaHide training, first prepare and preprocess your public dataset and save it under `PATH/TO/FILTERED_PUBLIC_DATA`. Then run the following training script: 59 | 60 | ``` 61 | python train_cross.py --mode instahide --klam 6 --data cifar10 --pair --dom 0.3 --help_dir PATH/TO/FILTERED_PUBLIC_DATA 62 | ``` 63 | 64 | ## Try InstaHide on new datasets or your own data? 65 | You can easily customize your own dataloader to test InstaHide on more datasets (see `train_inside.py` and `train_cross.py`, around the 'Prepare data' section). 66 | 67 | You can also try new models by defining their architectures under the `models/` folder. 68 | 69 | ## Questions 70 | If you have any questions, please open an issue or contact yangsibo@princeton.edu. 71 | 72 | 73 | 74 | ## References: 75 | [1] [**InstaHide: Instance-hiding Schemes for Private Distributed Learning**](http://arxiv.org/abs/2010.02772), *Yangsibo Huang, Zhao Song, Kai Li, Sanjeev Arora*, ICML 2020 76 | 77 | [2] [**mixup: Beyond Empirical Risk Minimization**](https://arxiv.org/abs/1710.09412), *Hongyi Zhang, Moustapha Cisse, Yann N.
Dauphin, David Lopez-Paz*, ICLR 2018 78 | 79 | [3] [**ImageNet: A Large-Scale Hierarchical Image Database.**](http://www.image-net.org/papers/imagenet_cvpr09.pdf), *Jia Deng, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, Li Fei-Fei*, CVPR 2009 -------------------------------------------------------------------------------- /models/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch.autograd import Variable 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | def __init__(self, in_planes, growth_rate): 13 | super(Bottleneck, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 16 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 17 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 18 | 19 | def forward(self, x): 20 | out = self.conv1(F.relu(self.bn1(x))) 21 | out = self.conv2(F.relu(self.bn2(out))) 22 | out = torch.cat([out,x], 1) 23 | return out 24 | 25 | 26 | class Transition(nn.Module): 27 | def __init__(self, in_planes, out_planes): 28 | super(Transition, self).__init__() 29 | self.bn = nn.BatchNorm2d(in_planes) 30 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 31 | 32 | def forward(self, x): 33 | out = self.conv(F.relu(self.bn(x))) 34 | out = F.avg_pool2d(out, 2) 35 | return out 36 | 37 | 38 | class DenseNet(nn.Module): 39 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10): 40 | super(DenseNet, self).__init__() 41 | self.growth_rate = growth_rate 42 | 43 | num_planes = 2*growth_rate 44 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 45 | 46 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 47 | num_planes += nblocks[0]*growth_rate 48 | out_planes = int(math.floor(num_planes*reduction)) 49 | self.trans1 = Transition(num_planes, out_planes) 50 | num_planes = out_planes 51 | 52 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 53 | num_planes += nblocks[1]*growth_rate 54 | out_planes = int(math.floor(num_planes*reduction)) 55 | self.trans2 = Transition(num_planes, out_planes) 56 | num_planes = out_planes 57 | 58 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 59 | num_planes += nblocks[2]*growth_rate 60 | out_planes = int(math.floor(num_planes*reduction)) 61 | self.trans3 = Transition(num_planes, out_planes) 62 | num_planes = out_planes 63 | 64 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 65 | num_planes += nblocks[3]*growth_rate 66 | 67 | self.bn = nn.BatchNorm2d(num_planes) 68 | self.linear = nn.Linear(num_planes, num_classes) 69 | 70 | def _make_dense_layers(self, block, in_planes, nblock): 71 | layers = [] 72 | for i in range(nblock): 73 | layers.append(block(in_planes, self.growth_rate)) 74 | in_planes += self.growth_rate 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | out = self.conv1(x) 79 | out = self.trans1(self.dense1(out)) 80 | out = self.trans2(self.dense2(out)) 81 | out = self.trans3(self.dense3(out)) 82 | out = self.dense4(out) 83 | out = F.avg_pool2d(F.relu(self.bn(out)), 4) 84 | out = out.view(out.size(0), -1) 85 | out = self.linear(out) 86 | return out 87 | 88 | def DenseNet121(): 89 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32) 90 | 91 | def DenseNet169(): 
92 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32) 93 | 94 | def DenseNet201(): 95 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32) 96 | 97 | def DenseNet161(): 98 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48) 99 | 100 | def densenet_cifar(): 101 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12) 102 | 103 | def test_densenet(): 104 | net = densenet_cifar() 105 | x = torch.randn(1,3,32,32) 106 | y = net(Variable(x)) 107 | print(y) 108 | 109 | # test_densenet() 110 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Some helper functions for PyTorch, including: 2 | - get_mean_and_std: calculate the mean and std value of dataset. 3 | - init_params: net parameter initialization. 4 | - progress_bar: progress bar mimic xlua.progress. 5 | ''' 6 | import os 7 | import sys 8 | import time 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | import matplotlib.pyplot as plt 15 | 16 | def get_mean_and_std(dataset): 17 | '''Compute the mean and std value of dataset.''' 18 | dataloader = torch.utils.data.DataLoader( 19 | dataset, batch_size=1, shuffle=True, num_workers=2) 20 | mean = torch.zeros(3) 21 | std = torch.zeros(3) 22 | print('==> Computing mean and std..') 23 | for inputs, targets in dataloader: 24 | for i in range(3): 25 | mean[i] += inputs[:, i, :, :].mean() 26 | std[i] += inputs[:, i, :, :].std() 27 | mean.div_(len(dataset)) 28 | std.div_(len(dataset)) 29 | return mean, std 30 | 31 | 32 | def init_params(net): 33 | '''Init layer parameters.''' 34 | for m in net.modules(): 35 | if isinstance(m, nn.Conv2d): 36 | init.kaiming_normal(m.weight, mode='fan_out') 37 | if m.bias is not None: 38 | init.constant(m.bias, 0) 39 | elif isinstance(m, nn.BatchNorm2d): 40 | init.constant(m.weight, 1) 41 | init.constant(m.bias, 0) 42 | elif isinstance(m, nn.Linear): 43 | init.normal(m.weight, std=1e-3) 44 | if m.bias is not None: 45 | init.constant(m.bias, 0) 46 | 47 | 48 | _, term_width = os.popen('stty size', 'r').read().split() 49 | term_width = int(term_width) 50 | 51 | TOTAL_BAR_LENGTH = 86. 52 | last_time = time.time() 53 | begin_time = last_time 54 | 55 | 56 | def progress_bar(current, total, msg=None): 57 | global last_time, begin_time 58 | if current == 0: 59 | begin_time = time.time() # Reset for new bar. 60 | 61 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 62 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 63 | 64 | sys.stdout.write(' [') 65 | for i in range(cur_len): 66 | sys.stdout.write('=') 67 | sys.stdout.write('>') 68 | for i in range(rest_len): 69 | sys.stdout.write('.') 70 | sys.stdout.write(']') 71 | 72 | cur_time = time.time() 73 | step_time = cur_time - last_time 74 | last_time = cur_time 75 | tot_time = cur_time - begin_time 76 | 77 | L = [] 78 | L.append(' Step: %s' % format_time(step_time)) 79 | L.append(' | Tot: %s' % format_time(tot_time)) 80 | if msg: 81 | L.append(' | ' + msg) 82 | 83 | msg = ''.join(L) 84 | sys.stdout.write(msg) 85 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 86 | sys.stdout.write(' ') 87 | 88 | # Go back to the center of the bar.
89 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)): 90 | sys.stdout.write('\b') 91 | sys.stdout.write(' %d/%d ' % (current+1, total)) 92 | 93 | if current < total-1: 94 | sys.stdout.write('\r') 95 | else: 96 | sys.stdout.write('\n') 97 | sys.stdout.flush() 98 | 99 | 100 | def format_time(seconds): 101 | days = int(seconds / 3600/24) 102 | seconds = seconds - days*3600*24 103 | hours = int(seconds / 3600) 104 | seconds = seconds - hours*3600 105 | minutes = int(seconds / 60) 106 | seconds = seconds - minutes*60 107 | secondsf = int(seconds) 108 | seconds = seconds - secondsf 109 | millis = int(seconds*1000) 110 | 111 | f = '' 112 | i = 1 113 | if days > 0: 114 | f += str(days) + 'D' 115 | i += 1 116 | if hours > 0 and i <= 2: 117 | f += str(hours) + 'h' 118 | i += 1 119 | if minutes > 0 and i <= 2: 120 | f += str(minutes) + 'm' 121 | i += 1 122 | if secondsf > 0 and i <= 2: 123 | f += str(secondsf) + 's' 124 | i += 1 125 | if millis > 0 and i <= 2: 126 | f += str(millis) + 'ms' 127 | i += 1 128 | if f == '': 129 | f = '0ms' 130 | return f 131 | 132 | 133 | def chunks(lst, n): 134 | for i in range(0, len(lst), n): 135 | yield lst[i:i + n] 136 | 137 | 138 | def save_fig(img_tensor, fname): 139 | img_arr = img_tensor.cpu().detach().permute(1, 2, 0) 140 | img_arr = (img_arr - img_arr.min())/(img_arr.max() - img_arr.min()) 141 | plt.imshow(img_arr) 142 | plt.axis('off') 143 | plt.savefig(fname, bbox_inches='tight') 144 | plt.show() -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | 3 | BasicBlock and Bottleneck module is from the original ResNet paper: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | 7 | PreActBlock and PreActBottleneck module is from the later paper: 8 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 10 | ''' 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | from torch.autograd import Variable 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | 21 | 22 | class BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, in_planes, planes, stride=1): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(in_planes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes) 29 | self.conv2 = conv3x3(planes, planes) 30 | self.bn2 = nn.BatchNorm2d(planes) 31 | 32 | self.shortcut = nn.Sequential() 33 | if stride != 1 or in_planes != self.expansion*planes: 34 | self.shortcut = nn.Sequential( 35 | nn.Conv2d(in_planes, self.expansion*planes, 36 | kernel_size=1, stride=stride, bias=False), 37 | nn.BatchNorm2d(self.expansion*planes) 38 | ) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.bn2(self.conv2(out)) 43 | out += self.shortcut(x) 44 | out = F.relu(out) 45 | return out 46 | 47 | 48 | class PreActBlock(nn.Module): 49 | '''Pre-activation version of the BasicBlock.''' 50 | expansion = 1 51 | 52 | def __init__(self, in_planes, planes, stride=1): 53 | super(PreActBlock, self).__init__() 54 | self.bn1 = nn.BatchNorm2d(in_planes) 55 | self.conv1 = conv3x3(in_planes, planes, stride) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv2 = conv3x3(planes, planes) 58 | 59 | self.shortcut = nn.Sequential() 60 | if stride != 1 or in_planes != self.expansion*planes: 61 | self.shortcut = nn.Sequential( 62 | nn.Conv2d(in_planes, self.expansion*planes, 63 | kernel_size=1, stride=stride, bias=False) 64 | ) 65 | 66 | def forward(self, x): 67 | out = F.relu(self.bn1(x)) 68 | shortcut = self.shortcut(out) 69 | out = self.conv1(out) 70 | out = self.conv2(F.relu(self.bn2(out))) 71 | out += shortcut 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, in_planes, planes, stride=1): 79 | super(Bottleneck, self).__init__() 80 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(planes) 82 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 83 | stride=stride, padding=1, bias=False) 84 | self.bn2 = nn.BatchNorm2d(planes) 85 | self.conv3 = nn.Conv2d(planes, self.expansion * 86 | planes, kernel_size=1, bias=False) 87 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 88 | 89 | self.shortcut = nn.Sequential() 90 | if stride != 1 or in_planes != self.expansion*planes: 91 | self.shortcut = nn.Sequential( 92 | nn.Conv2d(in_planes, self.expansion*planes, 93 | kernel_size=1, stride=stride, bias=False), 94 | nn.BatchNorm2d(self.expansion*planes) 95 | ) 96 | 97 | def forward(self, x): 98 | out = F.relu(self.bn1(self.conv1(x))) 99 | out = F.relu(self.bn2(self.conv2(out))) 100 | out = self.bn3(self.conv3(out)) 101 | out += self.shortcut(x) 102 | out = F.relu(out) 103 | return out 104 | 105 | 106 | class PreActBottleneck(nn.Module): 107 | '''Pre-activation version of the original Bottleneck module.''' 108 | expansion = 4 109 | 110 | def __init__(self, in_planes, planes, stride=1): 111 | super(PreActBottleneck, self).__init__() 112 | self.bn1 = nn.BatchNorm2d(in_planes) 113 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 114 | self.bn2 = nn.BatchNorm2d(planes) 115 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 116 | stride=stride, padding=1, bias=False) 117 | 
self.bn3 = nn.BatchNorm2d(planes) 118 | self.conv3 = nn.Conv2d(planes, self.expansion * 119 | planes, kernel_size=1, bias=False) 120 | 121 | self.shortcut = nn.Sequential() 122 | if stride != 1 or in_planes != self.expansion*planes: 123 | self.shortcut = nn.Sequential( 124 | nn.Conv2d(in_planes, self.expansion*planes, 125 | kernel_size=1, stride=stride, bias=False) 126 | ) 127 | 128 | def forward(self, x): 129 | out = F.relu(self.bn1(x)) 130 | shortcut = self.shortcut(out) 131 | out = self.conv1(out) 132 | out = self.conv2(F.relu(self.bn2(out))) 133 | out = self.conv3(F.relu(self.bn3(out))) 134 | out += shortcut 135 | return out 136 | 137 | 138 | class ResNet(nn.Module): 139 | def __init__(self, block, num_blocks, num_classes=10): 140 | super(ResNet, self).__init__() 141 | self.in_planes = 64 142 | 143 | self.conv1 = conv3x3(3, 64) 144 | self.bn1 = nn.BatchNorm2d(64) 145 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 146 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 147 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 148 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 149 | self.linear = nn.Linear(512*block.expansion, num_classes) 150 | 151 | def _make_layer(self, block, planes, num_blocks, stride): 152 | strides = [stride] + [1]*(num_blocks-1) 153 | layers = [] 154 | for stride in strides: 155 | layers.append(block(self.in_planes, planes, stride)) 156 | self.in_planes = planes * block.expansion 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x, lin=0, lout=5): 160 | out = x 161 | if lin < 1 and lout > -1: 162 | out = self.conv1(out) 163 | out = self.bn1(out) 164 | out = F.relu(out) 165 | if lin < 2 and lout > 0: 166 | out = self.layer1(out) 167 | if lin < 3 and lout > 1: 168 | out = self.layer2(out) 169 | if lin < 4 and lout > 2: 170 | out = self.layer3(out) 171 | if lin < 5 and lout > 3: 172 | out = self.layer4(out) 173 | if lout > 4: 174 | out = F.avg_pool2d(out, 4) 175 | out = out.view(out.size(0), -1) 176 | out = self.linear(out) 177 | return out 178 | 179 | 180 | def ResNet18(num_classes=10): 181 | return ResNet(PreActBlock, [2, 2, 2, 2], num_classes=num_classes) 182 | 183 | 184 | def ResNet34(num_classes=10): 185 | return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes) 186 | 187 | 188 | def ResNet50(num_classes=10): 189 | return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes) 190 | 191 | 192 | def ResNet101(num_classes=10): 193 | return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes) 194 | 195 | 196 | def ResNet152(num_classes=10): 197 | return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes) 198 | 199 | 200 | def test(): 201 | net = ResNet18() 202 | y = net(Variable(torch.randn(1, 3, 32, 32))) 203 | print(y.size()) 204 | 205 | # test() 206 | -------------------------------------------------------------------------------- /models/nasnet.py: -------------------------------------------------------------------------------- 1 | """nasnet in pytorch 2 | [1] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. 
Le 3 | Learning Transferable Architectures for Scalable Image Recognition 4 | https://arxiv.org/abs/1707.07012 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class SeperableConv2d(nn.Module): 12 | 13 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 14 | 15 | super().__init__() 16 | self.depthwise = nn.Conv2d( 17 | input_channels, 18 | input_channels, 19 | kernel_size, 20 | groups=input_channels, 21 | **kwargs 22 | ) 23 | 24 | self.pointwise = nn.Conv2d( 25 | input_channels, 26 | output_channels, 27 | 1 28 | ) 29 | 30 | def forward(self, x): 31 | x = self.depthwise(x) 32 | x = self.pointwise(x) 33 | 34 | return x 35 | 36 | 37 | class SeperableBranch(nn.Module): 38 | 39 | def __init__(self, input_channels, output_channels, kernel_size, **kwargs): 40 | """Adds 2 blocks of [relu-separable conv-batchnorm].""" 41 | super().__init__() 42 | self.block1 = nn.Sequential( 43 | nn.ReLU(), 44 | SeperableConv2d(input_channels, output_channels, 45 | kernel_size, **kwargs), 46 | nn.BatchNorm2d(output_channels) 47 | ) 48 | 49 | self.block2 = nn.Sequential( 50 | nn.ReLU(), 51 | SeperableConv2d(output_channels, output_channels, 52 | kernel_size, stride=1, padding=int(kernel_size / 2)), 53 | nn.BatchNorm2d(output_channels) 54 | ) 55 | 56 | def forward(self, x): 57 | x = self.block1(x) 58 | x = self.block2(x) 59 | 60 | return x 61 | 62 | 63 | class Fit(nn.Module): 64 | """Make the cell outputs compatible 65 | Args: 66 | prev_filters: filter number of tensor prev, needs to be modified 67 | filters: filter number of normal cell branch output filters 68 | """ 69 | 70 | def __init__(self, prev_filters, filters): 71 | super().__init__() 72 | self.relu = nn.ReLU() 73 | 74 | self.p1 = nn.Sequential( 75 | nn.AvgPool2d(1, stride=2), 76 | nn.Conv2d(prev_filters, int(filters / 2), 1) 77 | ) 78 | 79 | # make sure there is no information loss 80 | self.p2 = nn.Sequential( 81 | nn.ConstantPad2d((0, 1, 0, 1), 0), 82 | nn.ConstantPad2d((-1, 0, -1, 0), 0), # cropping 83 | nn.AvgPool2d(1, stride=2), 84 | nn.Conv2d(prev_filters, int(filters / 2), 1) 85 | ) 86 | 87 | self.bn = nn.BatchNorm2d(filters) 88 | 89 | self.dim_reduce = nn.Sequential( 90 | nn.ReLU(), 91 | nn.Conv2d(prev_filters, filters, 1), 92 | nn.BatchNorm2d(filters) 93 | ) 94 | 95 | self.filters = filters 96 | 97 | def forward(self, inputs): 98 | x, prev = inputs 99 | if prev is None: 100 | return x 101 | 102 | # image size does not match 103 | elif x.size(2) != prev.size(2): 104 | prev = self.relu(prev) 105 | p1 = self.p1(prev) 106 | p2 = self.p2(prev) 107 | prev = torch.cat([p1, p2], 1) 108 | prev = self.bn(prev) 109 | 110 | elif prev.size(1) != self.filters: 111 | prev = self.dim_reduce(prev) 112 | 113 | return prev 114 | 115 | 116 | class NormalCell(nn.Module): 117 | 118 | def __init__(self, x_in, prev_in, output_channels): 119 | super().__init__() 120 | 121 | self.dem_reduce = nn.Sequential( 122 | nn.ReLU(), 123 | nn.Conv2d(x_in, output_channels, 1, bias=False), 124 | nn.BatchNorm2d(output_channels) 125 | ) 126 | 127 | self.block1_left = SeperableBranch( 128 | output_channels, 129 | output_channels, 130 | kernel_size=3, 131 | padding=1, 132 | bias=False 133 | ) 134 | self.block1_right = nn.Sequential() 135 | 136 | self.block2_left = SeperableBranch( 137 | output_channels, 138 | output_channels, 139 | kernel_size=3, 140 | padding=1, 141 | bias=False 142 | ) 143 | self.block2_right = SeperableBranch( 144 | output_channels, 145 | output_channels, 146 | kernel_size=5, 147 | padding=2, 148 | bias=False 149 | ) 150 | 
151 | self.block3_left = nn.AvgPool2d(3, stride=1, padding=1) 152 | self.block3_right = nn.Sequential() 153 | 154 | self.block4_left = nn.AvgPool2d(3, stride=1, padding=1) 155 | self.block4_right = nn.AvgPool2d(3, stride=1, padding=1) 156 | 157 | self.block5_left = SeperableBranch( 158 | output_channels, 159 | output_channels, 160 | kernel_size=5, 161 | padding=2, 162 | bias=False 163 | ) 164 | self.block5_right = SeperableBranch( 165 | output_channels, 166 | output_channels, 167 | kernel_size=3, 168 | padding=1, 169 | bias=False 170 | ) 171 | 172 | self.fit = Fit(prev_in, output_channels) 173 | 174 | def forward(self, x): 175 | x, prev = x 176 | 177 | # return transformed x as new x, and original x as prev 178 | # only prev tensor needs to be modified 179 | prev = self.fit((x, prev)) 180 | 181 | h = self.dem_reduce(x) 182 | 183 | x1 = self.block1_left(h) + self.block1_right(h) 184 | x2 = self.block2_left(prev) + self.block2_right(h) 185 | x3 = self.block3_left(h) + self.block3_right(h) 186 | x4 = self.block4_left(prev) + self.block4_right(prev) 187 | x5 = self.block5_left(prev) + self.block5_right(prev) 188 | 189 | return torch.cat([prev, x1, x2, x3, x4, x5], 1), x 190 | 191 | 192 | class ReductionCell(nn.Module): 193 | 194 | def __init__(self, x_in, prev_in, output_channels): 195 | super().__init__() 196 | 197 | self.dim_reduce = nn.Sequential( 198 | nn.ReLU(), 199 | nn.Conv2d(x_in, output_channels, 1), 200 | nn.BatchNorm2d(output_channels) 201 | ) 202 | 203 | # block1 204 | self.layer1block1_left = SeperableBranch( 205 | output_channels, output_channels, 7, stride=2, padding=3) 206 | self.layer1block1_right = SeperableBranch( 207 | output_channels, output_channels, 5, stride=2, padding=2) 208 | 209 | # block2 210 | self.layer1block2_left = nn.MaxPool2d(3, stride=2, padding=1) 211 | self.layer1block2_right = SeperableBranch( 212 | output_channels, output_channels, 7, stride=2, padding=3) 213 | 214 | # block3 215 | self.layer1block3_left = nn.AvgPool2d(3, 2, 1) 216 | self.layer1block3_right = SeperableBranch( 217 | output_channels, output_channels, 5, stride=2, padding=2) 218 | 219 | # block5 220 | self.layer2block1_left = nn.MaxPool2d(3, 2, 1) 221 | self.layer2block1_right = SeperableBranch( 222 | output_channels, output_channels, 3, stride=1, padding=1) 223 | 224 | # block4 225 | self.layer2block2_left = nn.AvgPool2d(3, 1, 1) 226 | self.layer2block2_right = nn.Sequential() 227 | 228 | self.fit = Fit(prev_in, output_channels) 229 | 230 | def forward(self, x): 231 | x, prev = x 232 | prev = self.fit((x, prev)) 233 | 234 | h = self.dim_reduce(x) 235 | 236 | layer1block1 = self.layer1block1_left( 237 | prev) + self.layer1block1_right(h) 238 | layer1block2 = self.layer1block2_left( 239 | h) + self.layer1block2_right(prev) 240 | layer1block3 = self.layer1block3_left( 241 | h) + self.layer1block3_right(prev) 242 | layer2block1 = self.layer2block1_left( 243 | h) + self.layer2block1_right(layer1block1) 244 | layer2block2 = self.layer2block2_left( 245 | layer1block1) + self.layer2block2_right(layer1block2) 246 | 247 | return torch.cat([ 248 | layer1block2, # https://github.com/keras-team/keras-applications/blob/master/keras_applications/nasnet.py line 739 249 | layer1block3, 250 | layer2block1, 251 | layer2block2 252 | ], 1), x 253 | 254 | 255 | class NasNetA(nn.Module): 256 | 257 | def __init__(self, repeat_cell_num, reduction_num, filters, stemfilter, num_classes=100): 258 | super().__init__() 259 | 260 | self.stem = nn.Sequential( 261 | nn.Conv2d(3, stemfilter, 3, padding=1, bias=False), 262 | 
nn.BatchNorm2d(stemfilter) 263 | ) 264 | 265 | self.prev_filters = stemfilter 266 | self.x_filters = stemfilter 267 | self.filters = filters 268 | 269 | self.cell_layers = self._make_layers(repeat_cell_num, reduction_num) 270 | 271 | self.relu = nn.ReLU() 272 | self.avg = nn.AdaptiveAvgPool2d(1) 273 | self.fc = nn.Linear(self.filters * 6, num_classes) 274 | 275 | def _make_normal(self, block, repeat, output): 276 | """make normal cell 277 | Args: 278 | block: cell type 279 | repeat: number of repeated normal cell 280 | output: output filters for each branch in normal cell 281 | Returns: 282 | stacked normal cells 283 | """ 284 | 285 | layers = [] 286 | for r in range(repeat): 287 | layers.append(block(self.x_filters, self.prev_filters, output)) 288 | self.prev_filters = self.x_filters 289 | self.x_filters = output * 6 # concatenate 6 branches 290 | 291 | return layers 292 | 293 | def _make_reduction(self, block, output): 294 | """make normal cell 295 | Args: 296 | block: cell type 297 | output: output filters for each branch in reduction cell 298 | Returns: 299 | reduction cell 300 | """ 301 | 302 | reduction = block(self.x_filters, self.prev_filters, output) 303 | self.prev_filters = self.x_filters 304 | self.x_filters = output * 4 # stack for 4 branches 305 | 306 | return reduction 307 | 308 | def _make_layers(self, repeat_cell_num, reduction_num): 309 | 310 | layers = [] 311 | for i in range(reduction_num): 312 | 313 | layers.extend(self._make_normal( 314 | NormalCell, repeat_cell_num, self.filters)) 315 | self.filters *= 2 316 | layers.append(self._make_reduction(ReductionCell, self.filters)) 317 | 318 | layers.extend(self._make_normal( 319 | NormalCell, repeat_cell_num, self.filters)) 320 | 321 | return nn.Sequential(*layers) 322 | 323 | def forward(self, x): 324 | 325 | x = self.stem(x) 326 | prev = None 327 | x, prev = self.cell_layers((x, prev)) 328 | x = self.relu(x) 329 | x = self.avg(x) 330 | x = x.view(x.size(0), -1) 331 | x = self.fc(x) 332 | 333 | return x 334 | 335 | 336 | def nasnet(num_classes=100): 337 | 338 | # stem filters must be 44, it's a pytorch workaround, cant change to other number 339 | return NasNetA(4, 2, 44, 44, num_classes=num_classes) 340 | -------------------------------------------------------------------------------- /train_inside.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import csv 7 | import os 8 | 9 | import numpy as np 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torch.nn.functional as F 19 | import random 20 | 21 | import models 22 | from utils import progress_bar, chunks, save_fig 23 | 24 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 25 | 26 | parser = argparse.ArgumentParser( 27 | description='PyTorch InstaHide Training, CIFAR-10') 28 | 29 | # Training configurations 30 | parser.add_argument('--model', 31 | default="ResNet18", 32 | type=str, 33 | help='model type (default: ResNet18)') 34 | parser.add_argument('--data', default='cifar10', type=str, 35 | help='dataset') 36 | parser.add_argument('--nclass', default=10, type=int, 37 | help='number of classes') 38 | 39 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 40 | parser.add_argument('--batch-size', 
default=128, type=int, help='batch size') 41 | parser.add_argument('--epoch', 42 | default=200, 43 | type=int, 44 | help='total epochs to run') 45 | parser.add_argument('--no-augment', 46 | dest='augment', 47 | action='store_false', 48 | help='turn off standard data augmentation (on by default)') 49 | parser.add_argument('--decay', default=1e-4, type=float, help='weight decay') 50 | 51 | # Saving configurations 52 | parser.add_argument('--name', default='inside', type=str, help='name of run') 53 | parser.add_argument('--seed', default=0, type=int, help='random seed') 54 | parser.add_argument('--resume', 55 | '-r', 56 | action='store_true', 57 | help='resume from checkpoint') 58 | 59 | # InstaHide configurations 60 | parser.add_argument('--klam', default=4, type=int, help='number of lambdas') 61 | parser.add_argument('--mode', default='instahide', 62 | type=str, help='InstaHide or Mixup') 63 | parser.add_argument('--upper', default=0.65, type=float, help='the upper bound of any coefficient') 64 | 65 | 66 | args = parser.parse_args() 67 | use_cuda = torch.cuda.is_available() 68 | device = torch.device("cuda" if use_cuda else "cpu") 69 | 70 | criterion = nn.CrossEntropyLoss() 71 | best_acc = 0 # best test accuracy 72 | 73 | ## --------------- Functions for train & eval --------------- ## 74 | 75 | 76 | def label_to_onehot(target, num_classes=args.nclass): 77 | '''Returns one-hot embeddings of scalar labels''' 78 | target = torch.unsqueeze(target, 1) 79 | onehot_target = torch.zeros(target.size( 80 | 0), num_classes, device=target.device) 81 | onehot_target.scatter_(1, target, 1) 82 | return onehot_target 83 | 84 | 85 | def cross_entropy_for_onehot(pred, target): 86 | return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1)) 87 | 88 | 89 | def mixup_criterion(pred, ys, lam_batch, num_class=args.nclass): 90 | '''Returns mixup loss''' 91 | ys_onehot = [label_to_onehot(y, num_classes=num_class) for y in ys] 92 | mixy = vec_mul_ten(lam_batch[:, 0], ys_onehot[0]) 93 | for i in range(1, args.klam): 94 | mixy += vec_mul_ten(lam_batch[:, i], ys_onehot[i]) 95 | l = cross_entropy_for_onehot(pred, mixy) 96 | return l 97 | 98 | 99 | def vec_mul_ten(vec, tensor): 100 | size = list(tensor.size()) 101 | size[0] = -1 102 | size_rs = [1 for i in range(len(size))] 103 | size_rs[0] = -1 104 | vec = vec.reshape(size_rs).expand(size) 105 | res = vec * tensor 106 | return res 107 | 108 | 109 | def mixup_data(x, y, use_cuda=True): 110 | '''Returns mixed inputs, lists of targets, and lambdas''' 111 | lams = np.random.normal(0, 1, size=(x.size()[0], args.klam)) 112 | for i in range(x.size()[0]): 113 | lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i])) 114 | if args.klam > 1: 115 | while lams[i].max() > args.upper: # upper bounds a single lambda 116 | lams[i] = np.random.normal(0, 1, size=(1, args.klam)) 117 | lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i])) 118 | 119 | lams = torch.from_numpy(lams).float().to(device) 120 | 121 | mixed_x = vec_mul_ten(lams[:, 0], x) 122 | ys = [y] 123 | 124 | for i in range(1, args.klam): 125 | batch_size = x.size()[0] 126 | index = torch.randperm(batch_size).to(device) 127 | mixed_x += vec_mul_ten(lams[:, i], x[index, :]) 128 | ys.append(y[index]) 129 | 130 | if args.mode == 'instahide': 131 | sign = torch.randint(2, size=list(x.shape), device=device) * 2.0 - 1 132 | mixed_x *= sign.float().to(device) 133 | return mixed_x, ys, lams 134 | 135 | 136 | def generate_sample(trainloader): 137 | assert len(trainloader) == 1 # Load all training data once 138 | for _, (inputs, targets)
in enumerate(trainloader): 139 | if use_cuda: 140 | inputs, targets = inputs.cuda(), targets.cuda() 141 | mix_inputs, mix_targets, lams = mixup_data( 142 | inputs, targets.float(), use_cuda) 143 | return (mix_inputs, mix_targets, lams) 144 | 145 | 146 | def train(net, optimizer, inputs_all, mix_targets_all, lams, epoch): 147 | print('\nEpoch: %d' % epoch) 148 | net.train() 149 | train_loss, correct, total = 0, 0, 0 150 | 151 | seq = random.sample(range(len(inputs_all)), len(inputs_all)) 152 | bl = list(chunks(seq, args.batch_size)) 153 | 154 | for batch_idx in range(len(bl)): 155 | b = bl[batch_idx] 156 | inputs = torch.stack([inputs_all[i] for i in b]) 157 | if args.mode == 'instahide' or args.mode == 'mixup': 158 | lam_batch = torch.stack([lams[i] for i in b]) 159 | 160 | mix_targets = [] 161 | for ik in range(args.klam): 162 | mix_targets.append( 163 | torch.stack( 164 | [mix_targets_all[ik][ib].long().to(device) for ib in b])) 165 | targets_var = [Variable(mix_targets[ik]) for ik in range(args.klam)] 166 | 167 | inputs = Variable(inputs) 168 | outputs = net(inputs) 169 | loss = mixup_criterion(outputs, targets_var, lam_batch) 170 | train_loss += loss.data.item() 171 | total += args.batch_size 172 | optimizer.zero_grad() 173 | loss.backward() 174 | optimizer.step() 175 | 176 | progress_bar(batch_idx, len(inputs_all)/args.batch_size+1, 177 | 'Loss: %.3f' % (train_loss / (batch_idx + 1))) 178 | return (train_loss / batch_idx, 100. * correct / total) 179 | 180 | 181 | def test(net, optimizer, testloader, epoch, start_epoch): 182 | global best_acc 183 | net.eval() 184 | test_loss, correct_1, correct_5, total = 0, 0, 0, 0 185 | with torch.no_grad(): 186 | for batch_idx, (inputs, targets) in enumerate(testloader): 187 | if use_cuda: 188 | inputs, targets = inputs.cuda(), targets.cuda() 189 | inputs, targets = Variable(inputs), Variable(targets) 190 | outputs = net(inputs) 191 | loss = criterion(outputs, targets) 192 | 193 | test_loss += loss.data.item() 194 | _, pred = outputs.topk(5, 1, largest=True, sorted=True) 195 | total += targets.size(0) 196 | correct = pred.eq(targets.view(targets.size(0), - 197 | 1).expand_as(pred)).float().cpu() 198 | correct_1 += correct[:, :1].sum() 199 | correct_5 += correct[:, :5].sum() 200 | 201 | progress_bar( 202 | batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % 203 | (test_loss / 204 | (batch_idx + 1), 100. * correct_1 / total, correct_1, total)) 205 | 206 | acc = 100. * correct_1 / total 207 | if epoch == start_epoch + args.epoch - 1 or acc > best_acc: 208 | save_checkpoint(net, acc, epoch) 209 | if acc > best_acc: 210 | best_acc = acc 211 | return (test_loss / batch_idx, 100. * correct_1 / total) 212 | 213 | 214 | def save_checkpoint(net, acc, epoch): 215 | """ Save checkpoints. """ 216 | print('Saving..') 217 | state = { 218 | 'net': net, 219 | 'acc': acc, 220 | 'epoch': epoch, 221 | 'rng_state': torch.get_rng_state() 222 | } 223 | if not os.path.isdir('checkpoint'): 224 | os.mkdir('checkpoint') 225 | ckptname = os.path.join( 226 | './checkpoint/', f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7') 227 | torch.save(state, ckptname) 228 | 229 | 230 | def adjust_learning_rate(optimizer, epoch): 231 | """ Decrease learning rate at certain epochs. 
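    For CIFAR-10, the learning rate is divided by 10 from epoch 100 and by 100 from epoch 150.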
""" 232 | lr = args.lr 233 | if args.data == 'cifar10': 234 | if epoch >= 100: 235 | lr /= 10 236 | if epoch >= 150: 237 | lr /= 10 238 | for param_group in optimizer.param_groups: 239 | param_group['lr'] = lr 240 | 241 | 242 | def main(): 243 | global best_acc 244 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 245 | 246 | if args.seed != 0: 247 | torch.manual_seed(args.seed) 248 | np.random.seed(args.seed) 249 | 250 | print('==> Number of lambdas: %g' % args.klam) 251 | 252 | ## --------------- Prepare data --------------- ## 253 | print('==> Preparing data..') 254 | 255 | cifar_normalize = transforms.Normalize( 256 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 257 | 258 | if args.augment: 259 | transform_cifar_train = transforms.Compose([ 260 | transforms.RandomCrop(32, padding=4), 261 | transforms.RandomHorizontalFlip(), 262 | transforms.ToTensor(), 263 | cifar_normalize 264 | ]) 265 | else: 266 | transform_cifar_train = transforms.Compose([ 267 | transforms.ToTensor(), 268 | cifar_normalize 269 | ]) 270 | 271 | transform_cifar_test = transforms.Compose([ 272 | transforms.ToTensor(), 273 | cifar_normalize 274 | ]) 275 | 276 | if args.data == 'cifar10': 277 | trainset = datasets.CIFAR10(root='./data', 278 | train=True, 279 | download=True, 280 | transform=transform_cifar_train) 281 | testset = datasets.CIFAR10(root='./data', 282 | train=False, 283 | download=True, 284 | transform=transform_cifar_test) 285 | num_class = 10 286 | # You can add your own dataloader and preprocessor here. 287 | 288 | trainloader = torch.utils.data.DataLoader(trainset, 289 | batch_size=len(trainset), 290 | shuffle=True, 291 | num_workers=8) 292 | 293 | testloader = torch.utils.data.DataLoader(testset, 294 | batch_size=args.batch_size, 295 | shuffle=False, 296 | num_workers=8) 297 | 298 | ## --------------- Create the model --------------- ## 299 | if args.resume: 300 | # Load checkpoint. 301 | print('==> Resuming from checkpoint..') 302 | assert os.path.isdir( 303 | 'checkpoint'), 'Error: no checkpoint directory found!' 
304 | checkpoint = torch.load(os.path.join('./checkpoint/', 305 | f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7')) 306 | net = checkpoint['net'] 307 | best_acc = checkpoint['acc'] 308 | start_epoch = checkpoint['epoch'] + 1 309 | rng_state = checkpoint['rng_state'] 310 | torch.set_rng_state(rng_state) 311 | else: 312 | print('==> Building model..') 313 | net = models.__dict__[args.model](num_classes=num_class) 314 | 315 | if not os.path.isdir('results'): 316 | os.mkdir('results') 317 | logname = f'results/log_{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.csv' 318 | 319 | if use_cuda: 320 | net.cuda() 321 | net = torch.nn.DataParallel(net) 322 | cudnn.benchmark = True 323 | print('==> Using CUDA..') 324 | 325 | optimizer = optim.SGD(net.parameters(), 326 | lr=args.lr, 327 | momentum=0.9, 328 | weight_decay=args.decay) 329 | 330 | ## --------------- Train and Eval --------------- ## 331 | if not os.path.exists(logname): 332 | with open(logname, 'w') as logfile: 333 | logwriter = csv.writer(logfile, delimiter='\t') 334 | logwriter.writerow([ 335 | 'Epoch', 'Train loss', 'Test loss', 336 | 'Test acc' 337 | ]) 338 | 339 | for epoch in range(start_epoch, args.epoch): 340 | mix_inputs_all, mix_targets_all, lams = generate_sample(trainloader) 341 | train_loss, _ = train( 342 | net, optimizer, mix_inputs_all, mix_targets_all, lams, epoch) 343 | test_loss, test_acc1 = test( 344 | net, optimizer, testloader, epoch, start_epoch) 345 | adjust_learning_rate(optimizer, epoch) 346 | with open(logname, 'a') as logfile: 347 | logwriter = csv.writer(logfile, delimiter='\t') 348 | logwriter.writerow( 349 | [epoch, train_loss, test_loss, test_acc1]) 350 | 351 | 352 | if __name__ == '__main__': 353 | main() 354 | -------------------------------------------------------------------------------- /train_cross.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 -u 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import csv 7 | import os 8 | 9 | import numpy as np 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torch.nn.functional as F 19 | import random 20 | 21 | import models 22 | from utils import progress_bar, chunks, save_fig 23 | 24 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 25 | 26 | parser = argparse.ArgumentParser( 27 | description='PyTorch InstaHide Training, CIFAR-10') 28 | 29 | # Training configurations 30 | parser.add_argument('--model', 31 | default="ResNet18", 32 | type=str, 33 | help='model type (default: ResNet18)') 34 | parser.add_argument('--data', default='cifar10', type=str, 35 | help='dataset') 36 | parser.add_argument('--nclass', default=10, type=int, 37 | help='number of classes') 38 | 39 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 40 | parser.add_argument('--batch-size', default=128, type=int, help='batch size') 41 | parser.add_argument('--epoch', 42 | default=200, 43 | type=int, 44 | help='total epochs to run') 45 | parser.add_argument('--no-augment', 46 | dest='augment', 47 | action='store_false', 48 | help='turn off standard data augmentation (on by default)') 49 | parser.add_argument('--decay', default=1e-4, type=float, help='weight decay') 50 | 51 | # Saving configurations 52 | parser.add_argument('--name', default='cross', type=str, help='name of run')
53 | parser.add_argument('--seed', default=0, type=int, help='random seed')
54 | parser.add_argument('--resume',
55 |                     '-r',
56 |                     action='store_true',
57 |                     help='resume from checkpoint')
58 | parser.add_argument('--help_dir',
59 |                     default='./data/imagenet_filter_40', type=str, help='path to the public (helper) dataset')
60 | 
61 | # InstaHide configurations
62 | parser.add_argument('--klam', default=4, type=int, help='number of lambdas')
63 | parser.add_argument('--mode', default='instahide',
64 |                     type=str, help='training mode: instahide or mixup')
65 | parser.add_argument('--pair', action='store_true', help='mix exactly two private images')
66 | parser.add_argument('--upper', default=0.65, type=float, help='the upper bound of any coefficient')
67 | parser.add_argument('--dom', default=0.3, type=float, help='the lower bound of the sum of coefficients of two private images')
68 | 
69 | 
70 | args = parser.parse_args()
71 | use_cuda = torch.cuda.is_available()
72 | device = torch.device("cuda" if use_cuda else "cpu")
73 | 
74 | criterion = nn.CrossEntropyLoss()
75 | best_acc = 0  # best test accuracy
76 | 
77 | ## --------------- Functions for train & eval --------------- ##
78 | 
79 | 
80 | def label_to_onehot(target, num_classes=args.nclass):
81 |     '''Returns one-hot embeddings of scalar labels'''
82 |     target = torch.unsqueeze(target, 1)
83 |     onehot_target = torch.zeros(target.size(
84 |         0), num_classes, device=target.device)
85 |     onehot_target.scatter_(1, target, 1)
86 |     return onehot_target
87 | 
88 | 
89 | def cross_entropy_for_onehot(pred, target):
90 |     return torch.mean(torch.sum(- target * F.log_softmax(pred, dim=-1), 1))
91 | 
92 | 
93 | def mixup_criterion(pred, ys, lam_batch, num_class=args.nclass):
94 |     '''Returns mixup loss'''
95 |     if args.pair:
96 |         inside_cnt = 2
97 |     else:
98 |         inside_cnt = (args.klam+1)//2
99 |     ys_onehot = [label_to_onehot(y, num_classes=num_class) for y in ys]
100 |     mixy = vec_mul_ten(lam_batch[:, 0], ys_onehot[0])
101 |     # for i in range(1, args.klam):
102 |     for i in range(1, inside_cnt):
103 |         mixy += vec_mul_ten(lam_batch[:, i], ys_onehot[i])
104 |     loss = cross_entropy_for_onehot(pred, mixy)
105 |     return loss
106 | 
107 | 
108 | def vec_mul_ten(vec, tensor):
109 |     size = list(tensor.size())
110 |     size[0] = -1
111 |     size_rs = [1 for i in range(len(size))]
112 |     size_rs[0] = -1
113 |     vec = vec.reshape(size_rs).expand(size)
114 |     res = vec * tensor
115 |     return res
116 | 
117 | 
118 | def mixup_data(x, y, x_help, use_cuda=True):
119 |     '''Returns mixed inputs, lists of targets, and lambdas'''
120 |     lams = np.random.normal(0, 1, size=(x.size()[0], args.klam))
121 |     for i in range(x.size()[0]):
122 |         lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))
123 |         if args.klam > 1:
124 |             while lams[i].max() > args.upper or (lams[i][0] + lams[i][1]) < args.dom:  # upper-bound any single lambda; lower-bound the sum of the two private lambdas
125 |                 lams[i] = np.random.normal(0, 1, size=(1, args.klam))
126 |                 lams[i] = np.abs(lams[i]) / np.sum(np.abs(lams[i]))
127 | 
128 |     lams = torch.from_numpy(lams).float().to(device)
129 | 
130 |     mixed_x = vec_mul_ten(lams[:, 0], x)
131 |     ys = [y]
132 | 
133 |     if args.pair:
134 |         inside_cnt = 2
135 |     else:
136 |         inside_cnt = (args.klam + 1)//2
137 | 
138 |     for i in range(1, args.klam):
139 |         batch_size = x.size()[0]
140 |         index = torch.randperm(batch_size).to(device)
141 |         if i < inside_cnt:
142 |             # mix private samples
143 |             mixed_x += vec_mul_ten(lams[:, i], x[index, :])
144 |         else:
145 |             # mix public samples
146 |             mixed_x += vec_mul_ten(lams[:, i], x_help[index, :])
147 |         ys.append(y[index])  # labels are appended for every slot, but the loss only uses the private ones
148 | 
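# NOTE (editorial, not in the original source): the block below implements
# InstaHide's random sign flip -- every pixel of the mixed image is multiplied
# by +1 or -1, drawn uniformly and independently, so only the magnitudes of
# the mixture are exposed; plain Mixup (args.mode == 'mixup') skips this step.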
149 |     if args.mode == 'instahide':
150 |         sign = torch.randint(2, size=list(x.shape), device=device) * 2.0 - 1
151 |         mixed_x *= sign.float().to(device)
152 |     return mixed_x, ys, lams
153 | 
154 | 
155 | def generate_sample(trainloader, inputs_help):
156 |     assert len(trainloader) == 1  # Load all training data once
157 |     inputs_help = inputs_help[torch.randperm(inputs_help.size()[0])]  # Important! Permute the public dataset as a quick fix to issue #2. TODO: improve this
158 |     for _, (inputs, targets) in enumerate(trainloader):
159 |         if use_cuda:
160 |             inputs, targets = inputs.cuda(), targets.cuda()
161 |         mix_inputs, mix_targets, lams = mixup_data(
162 |             inputs, targets.float(), inputs_help, use_cuda)
163 |     return (mix_inputs, mix_targets, lams)
164 | 
165 | 
166 | def train(net, optimizer, inputs_all, mix_targets_all, lams, epoch):
167 |     print('\nEpoch: %d' % epoch)
168 |     net.train()
169 |     train_loss, correct, total = 0, 0, 0  # `correct` is never updated; main() ignores the returned accuracy
170 | 
171 |     seq = random.sample(range(len(inputs_all)), len(inputs_all))
172 |     bl = list(chunks(seq, args.batch_size))
173 | 
174 |     for batch_idx in range(len(bl)):
175 |         b = bl[batch_idx]
176 |         inputs = torch.stack([inputs_all[i] for i in b])
177 |         if args.mode == 'instahide' or args.mode == 'mixup':
178 |             lam_batch = torch.stack([lams[i] for i in b])
179 | 
180 |         mix_targets = []
181 |         for ik in range(args.klam):
182 |             mix_targets.append(
183 |                 torch.stack(
184 |                     [mix_targets_all[ik][ib].long().to(device) for ib in b]))
185 |         targets_var = [Variable(mix_targets[ik]) for ik in range(args.klam)]
186 | 
187 |         inputs = Variable(inputs)
188 |         outputs = net(inputs)
189 |         loss = mixup_criterion(outputs, targets_var, lam_batch)
190 |         train_loss += loss.data.item()
191 |         total += args.batch_size
192 |         optimizer.zero_grad()
193 |         loss.backward()
194 |         optimizer.step()
195 | 
196 |         progress_bar(batch_idx, len(bl),
197 |                      'Loss: %.3f' % (train_loss / (batch_idx + 1)))
198 |     return (train_loss / len(bl), 100. * correct / total)
199 | 
200 | 
201 | def test(net, optimizer, testloader, epoch, start_epoch):
202 |     global best_acc
203 |     net.eval()
204 |     test_loss, correct_1, correct_5, total = 0, 0, 0, 0
205 |     with torch.no_grad():
206 |         for batch_idx, (inputs, targets) in enumerate(testloader):
207 |             if use_cuda:
208 |                 inputs, targets = inputs.cuda(), targets.cuda()
209 |             inputs, targets = Variable(inputs), Variable(targets)
210 |             outputs = net(inputs)
211 |             loss = criterion(outputs, targets)
212 | 
213 |             test_loss += loss.data.item()
214 |             _, pred = outputs.topk(5, 1, largest=True, sorted=True)
215 |             total += targets.size(0)
216 |             correct = pred.eq(targets.view(targets.size(0), -1)
217 |                               .expand_as(pred)).float().cpu()
218 |             correct_1 += correct[:, :1].sum()
219 |             correct_5 += correct[:, :5].sum()
220 | 
221 |             progress_bar(
222 |                 batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)' %
223 |                 (test_loss /
224 |                  (batch_idx + 1), 100. * correct_1 / total, correct_1, total))
225 | 
226 |     acc = 100. * correct_1 / total
227 |     if epoch == start_epoch + args.epoch - 1 or acc > best_acc:
228 |         save_checkpoint(net, acc, epoch)
229 |     if acc > best_acc:
230 |         best_acc = acc
231 |     return (test_loss / len(testloader), 100. * correct_1 / total)
232 | 
233 | 
234 | def save_checkpoint(net, acc, epoch):
235 |     """ Save checkpoints. """
""" 236 | print('Saving..') 237 | state = { 238 | 'net': net, 239 | 'acc': acc, 240 | 'epoch': epoch, 241 | 'rng_state': torch.get_rng_state() 242 | } 243 | if not os.path.isdir('checkpoint'): 244 | os.mkdir('checkpoint') 245 | ckptname = os.path.join( 246 | './checkpoint/', f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7') 247 | torch.save(state, ckptname) 248 | 249 | 250 | def adjust_learning_rate(optimizer, epoch): 251 | """ Decrease learning rate at certain epochs. """ 252 | lr = args.lr 253 | if args.data == 'cifar10': 254 | if epoch >= 100: 255 | lr /= 10 256 | if epoch >= 150: 257 | lr /= 10 258 | for param_group in optimizer.param_groups: 259 | param_group['lr'] = lr 260 | 261 | 262 | def main(): 263 | global best_acc 264 | start_epoch = 0 # start from epoch 0 or last checkpoint epoch 265 | 266 | if args.seed != 0: 267 | torch.manual_seed(args.seed) 268 | np.random.seed(args.seed) 269 | 270 | print('==> Number of lambdas: %g' % args.klam) 271 | 272 | ## --------------- Prepare data --------------- ## 273 | print('==> Preparing data..') 274 | 275 | cifar_normalize = transforms.Normalize( 276 | (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 277 | 278 | transform_imagenet = transforms.Compose([ 279 | transforms.Resize(40), 280 | transforms.RandomCrop(32), 281 | transforms.ToTensor(), 282 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 283 | ]) 284 | 285 | if args.augment: 286 | transform_cifar_train = transforms.Compose([ 287 | transforms.RandomCrop(32, padding=4), 288 | transforms.RandomHorizontalFlip(), 289 | transforms.ToTensor(), 290 | cifar_normalize 291 | ]) 292 | else: 293 | transform_cifar_train = transforms.Compose([ 294 | transforms.ToTensor(), 295 | cifar_normalize 296 | ]) 297 | 298 | transform_cifar_test = transforms.Compose([ 299 | transforms.ToTensor(), 300 | cifar_normalize 301 | ]) 302 | 303 | if args.data == 'cifar10': 304 | trainset = datasets.CIFAR10(root='./data', 305 | train=True, 306 | download=True, 307 | transform=transform_cifar_train) 308 | testset = datasets.CIFAR10(root='./data', 309 | train=False, 310 | download=True, 311 | transform=transform_cifar_test) 312 | trainset_help = datasets.ImageFolder( 313 | args.help_dir, transform=transform_imagenet) 314 | num_class = 10 315 | # You can add your own dataloader and preprocessor here. 316 | 317 | trainloader = torch.utils.data.DataLoader(trainset, 318 | batch_size=len(trainset), 319 | shuffle=True, 320 | num_workers=8) 321 | 322 | # TODO: a more memory-efficient implementation: index offline and only load public samples for encryption per epoch 323 | trainloader_help = torch.utils.data.DataLoader(trainset_help, 324 | batch_size=len( 325 | trainset_help), 326 | shuffle=True, 327 | num_workers=8) 328 | 329 | testloader = torch.utils.data.DataLoader(testset, 330 | batch_size=args.batch_size, 331 | shuffle=False, 332 | num_workers=8) 333 | 334 | ## --------------- Create the model --------------- ## 335 | if args.resume: 336 | # Load checkpoint. 337 | print('==> Resuming from checkpoint..') 338 | assert os.path.isdir( 339 | 'checkpoint'), 'Error: no checkpoint directory found!' 
340 |         checkpoint = torch.load(os.path.join(  # must match save_checkpoint()'s naming
341 |             './checkpoint/', f'{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.t7'))
342 |         net = checkpoint['net']
343 |         best_acc = checkpoint['acc']
344 |         start_epoch = checkpoint['epoch'] + 1
345 |         rng_state = checkpoint['rng_state']
346 |         torch.set_rng_state(rng_state)
347 |     else:
348 |         print('==> Building model..')
349 |         net = models.__dict__[args.model](num_classes=num_class)
350 | 
351 |     if not os.path.isdir('results'):
352 |         os.mkdir('results')
353 |     logname = f'results/log_{args.model}_{args.data}_{args.mode}_{args.klam}_{args.name}_{args.seed}.csv'
354 | 
355 |     if use_cuda:
356 |         net.cuda()
357 |         net = torch.nn.DataParallel(net)
358 |         cudnn.benchmark = True
359 |         print('==> Using CUDA..')
360 | 
361 |     optimizer = optim.SGD(net.parameters(),
362 |                           lr=args.lr,
363 |                           momentum=0.9,
364 |                           weight_decay=args.decay)
365 | 
366 |     ## --------------- Train and Eval --------------- ##
367 |     if not os.path.exists(logname):
368 |         with open(logname, 'w') as logfile:
369 |             logwriter = csv.writer(logfile, delimiter='\t')
370 |             logwriter.writerow([
371 |                 'Epoch', 'Train loss', 'Test loss',
372 |                 'Test acc'
373 |             ])
374 | 
375 |     # Load the whole public (helper) dataset once; this loader has a single batch.
376 |     for inputs_help, _ in trainloader_help:
377 |         if use_cuda:
378 |             inputs_help = inputs_help.cuda()
379 | 
380 |     for epoch in range(start_epoch, args.epoch):
381 |         mix_inputs_all, mix_targets_all, lams = generate_sample(
382 |             trainloader, inputs_help)
383 |         train_loss, _ = train(
384 |             net, optimizer, mix_inputs_all, mix_targets_all, lams, epoch)
385 |         test_loss, test_acc1 = test(
386 |             net, optimizer, testloader, epoch, start_epoch)
387 |         adjust_learning_rate(optimizer, epoch)
388 |         with open(logname, 'a') as logfile:
389 |             logwriter = csv.writer(logfile, delimiter='\t')
390 |             logwriter.writerow(
391 |                 [epoch, train_loss, test_loss, test_acc1])
392 | 
393 | 
394 | if __name__ == '__main__':
395 |     main()
396 | 
--------------------------------------------------------------------------------
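Editor's note: the following is a minimal, self-contained sketch of the InstaHide encryption step shared by mixup_data() in both training scripts; it is illustrative only and not part of the repo (the function and variable names here, such as encrypt_batch, are hypothetical). It ignores label handling and the --pair flag, and simplifies the private/public split to the default inside_cnt = (k+1)//2. For the step decay in adjust_learning_rate, torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1) would be the idiomatic equivalent.

'''instahide_sketch.py -- illustrative sketch, not part of the original repo.'''
import torch


def encrypt_batch(x_priv, x_pub, k=4, upper=0.65, dom=0.3):
    """x_priv, x_pub: (n, C, H, W) tensors. Returns the encrypted batch and lambdas.
    Assumes k >= 2, mirroring the repo's defaults (klam=4)."""
    n = x_priv.size(0)
    inside = (k + 1) // 2  # number of private slots, as in mixup_data()

    # Rejection-sample coefficients: each row is non-negative and sums to 1,
    # no single coefficient exceeds `upper`, and the two leading (private)
    # coefficients sum to at least `dom` -- the same constraints as the
    # while-loop in mixup_data().
    lams = torch.empty(n, k)
    for i in range(n):
        while True:
            lam = torch.randn(k).abs()
            lam = lam / lam.sum()
            if lam.max() <= upper and lam[0] + lam[1] >= dom:
                lams[i] = lam
                break

    # Convex combination: slot 0 is the image itself, slots 1..inside-1 are
    # permuted private images, the rest are permuted public images.
    mixed = lams[:, 0].view(n, 1, 1, 1) * x_priv
    for j in range(1, k):
        src = x_priv if j < inside else x_pub
        perm = torch.randperm(src.size(0))[:n]
        mixed = mixed + lams[:, j].view(n, 1, 1, 1) * src[perm]

    # InstaHide's random pixel-wise sign flip.
    sign = torch.randint(0, 2, mixed.shape, dtype=mixed.dtype) * 2 - 1
    return mixed * sign, lams


if __name__ == '__main__':
    priv = torch.rand(8, 3, 32, 32)   # stand-in for a private CIFAR-10 batch
    pub = torch.rand(64, 3, 32, 32)   # stand-in for the public helper set
    enc, lams = encrypt_batch(priv, pub)
    print(enc.shape, lams.sum(dim=1))  # (8, 3, 32, 32); each lambda row sums to 1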