├── figure
    ├── Table1.jpg
    ├── Table2.jpg
    └── figure.jpg
├── License.txt
├── THIRD PARTY OPEN SOURCE SOFTWARE NOTICE.txt
├── README.md
├── lenet.py
├── resnet.py
├── teacher-train.py
└── DAFL-train.py

/figure/Table1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/Table1.jpg
--------------------------------------------------------------------------------
/figure/Table2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/Table2.jpg
--------------------------------------------------------------------------------
/figure/figure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/figure.jpg
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
Copyright (c) 2019. Huawei Technologies Co., Ltd.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/THIRD PARTY OPEN SOURCE SOFTWARE NOTICE.txt:
--------------------------------------------------------------------------------
Please note we provide an open source software notice for the third party open source software along with this software and/or this software component contributed by Huawei (in the following just “this SOFTWARE”). The open source software licenses are granted by the respective right holders.

Warranty Disclaimer
THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.

Copyright Notice and License Texts
Software: pytorch-cifar
Copyright notice:
Copyright (c) 2017

License: MIT License

Copyright (c) 2017 liukuang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DAFL: Data-Free Learning of Student Networks
This code is the PyTorch implementation of the ICCV 2019 paper [DAFL: Data-Free Learning of Student Networks](https://arxiv.org/pdf/1904.01186.pdf).

We propose a novel framework for training efficient deep neural networks by exploiting generative adversarial networks (GANs). To be specific, the pre-trained teacher network is regarded as a fixed discriminator, and the generator is used to produce training samples that obtain the maximum response on the discriminator. Then, an efficient network with a smaller model size and lower computational complexity is trained using the generated data and the teacher network simultaneously.

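The "maximum response" objective can be illustrated numerically. The sketch below is not part of the repository: it re-implements, in plain Python and on hand-written logits, the three generator terms that `DAFL-train.py` computes from the teacher's outputs (one-hot, activation, and information entropy). The function names and toy batches are hypothetical; the default weights 1, 5 and 0.1 are the `--oh`, `--ie` and `--a` defaults from the script.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def one_hot_loss(batch_logits):
    """Mean cross-entropy between teacher outputs and their own argmax labels."""
    losses = [-math.log(max(softmax(row))) for row in batch_logits]
    return sum(losses) / len(losses)


def activation_loss(batch_features):
    """Negative mean absolute value of the teacher's features."""
    values = [abs(v) for row in batch_features for v in row]
    return -sum(values) / len(values)


def information_entropy_loss(batch_logits):
    """sum_k p_k * log10(p_k), where p is the softmax averaged over the batch."""
    probs = [softmax(row) for row in batch_logits]
    n_classes = len(probs[0])
    mean = [sum(p[k] for p in probs) / len(probs) for k in range(n_classes)]
    return sum(p * math.log10(p) for p in mean)


def generator_loss(batch_logits, batch_features, oh=1.0, ie=5.0, a=0.1):
    """Weighted sum of the three terms (weights are the script defaults)."""
    return (oh * one_hot_loss(batch_logits)
            + ie * information_entropy_loss(batch_logits)
            + a * activation_loss(batch_features))


# Toy teacher outputs: both batches are equally confident per sample,
# but only the first one covers all three classes.
balanced = [[8.0, 0.0, 0.0], [0.0, 8.0, 0.0], [0.0, 0.0, 8.0]]
collapsed = [[8.0, 0.0, 0.0]] * 3
features = [[1.0, -3.0], [0.5, 2.0], [-1.5, 0.0]]  # made-up feature values

print("balanced :", generator_loss(balanced, features))
print("collapsed:", generator_loss(collapsed, features))
```

Both toy batches receive the same one-hot loss, but the balanced batch has the lower information-entropy term, so the generator is rewarded for producing samples spread across all classes rather than collapsing onto one.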



## Requirements
- python 3
- pytorch >= 1.0.0
- torchvision

## Run the demo
First, train a teacher network:
```shell
python teacher-train.py
```
Then use DAFL to train a student network on the MNIST dataset without any training data:
```shell
python DAFL-train.py
```

To run DAFL on the CIFAR-10 dataset:
```shell
python teacher-train.py --dataset cifar10
python DAFL-train.py --dataset cifar10 --channels 3 --n_epochs 2000 --batch_size 1024 --lr_G 0.02 --lr_S 0.1 --latent_dim 1000
```

To run DAFL on the CIFAR-100 dataset:
```shell
python teacher-train.py --dataset cifar100
python DAFL-train.py --dataset cifar100 --channels 3 --n_epochs 2000 --batch_size 1024 --lr_G 0.02 --lr_S 0.1 --latent_dim 1000 --oh 0.5 --ie 20
```

## Results




## Citation
    @inproceedings{DAFL,
        title={DAFL: Data-Free Learning of Student Networks},
        author={Chen, Hanting and Wang, Yunhe and Xu, Chang and Yang, Zhaohui and Liu, Chuanjian and Shi, Boxin and Xu, Chunjing and Xu, Chao and Tian, Qi},
        booktitle={ICCV},
        year={2019}
    }

## Contributing
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion.

If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.

--------------------------------------------------------------------------------
/lenet.py:
--------------------------------------------------------------------------------
#Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.

import torch.nn as nn


class LeNet5(nn.Module):
    """LeNet-5 for 1x32x32 inputs; used as the MNIST teacher."""

    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5))
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=(5, 5))
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(84, 10)

    def forward(self, img, out_feature=False):
        output = self.conv1(img)
        output = self.relu1(output)
        output = self.maxpool1(output)
        output = self.conv2(output)
        output = self.relu2(output)
        output = self.maxpool2(output)
        output = self.conv3(output)
        output = self.relu3(output)
        # Flattened conv3 output; returned as the feature when out_feature=True.
        feature = output.view(-1, 120)
        output = self.fc1(feature)
        output = self.relu4(output)
        output = self.fc2(output)
        if out_feature:
            return output, feature
        return output


class LeNet5Half(nn.Module):
    """LeNet-5 with half the channels per layer; used as the MNIST student."""

    def __init__(self):
        super(LeNet5Half, self).__init__()

        self.conv1 = nn.Conv2d(1, 3, kernel_size=(5, 5))
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv2 = nn.Conv2d(3, 8, kernel_size=(5, 5))
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.conv3 = nn.Conv2d(8, 60, kernel_size=(5, 5))
        self.relu3 = nn.ReLU()
        self.fc1 = nn.Linear(60, 42)
        self.relu4 = nn.ReLU()
        self.fc2 = nn.Linear(42, 10)

    def forward(self, img, out_feature=False):
        output = self.conv1(img)
        output = self.relu1(output)
        output = self.maxpool1(output)
        output = self.conv2(output)
        output = self.relu2(output)
        output = self.maxpool2(output)
        output = self.conv3(output)
        output = self.relu3(output)
        feature = output.view(-1, 60)
        output = self.fc1(feature)
        output = self.relu4(output)
        output = self.fc2(output)
        if out_feature:
            return output, feature
        return output
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
# 2019.07.24-Changed output of forward function
# Huawei Technologies Co., Ltd.

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest use stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        # Pooled, flattened activations; returned as the feature when out_feature=True.
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature:
            return out, feature
        return out


def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)

def ResNet34(num_classes=10):
    return ResNet(BasicBlock, [3,4,6,3], num_classes)

def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)

def ResNet101(num_classes=10):
    return ResNet(Bottleneck, [3,4,23,3], num_classes)

def ResNet152(num_classes=10):
    return ResNet(Bottleneck, [3,8,36,3], num_classes)

--------------------------------------------------------------------------------
/teacher-train.py:
--------------------------------------------------------------------------------
#Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.

import os
from lenet import LeNet5
import resnet
import torch
from torch.autograd import Variable
from torchvision.datasets.mnist import MNIST
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import argparse

parser = argparse.ArgumentParser(description='train-teacher-network')

# Basic model parameters.
parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST','cifar10','cifar100'])
parser.add_argument('--data', type=str, default='/cache/data/')
parser.add_argument('--output_dir', type=str, default='/cache/models/')
args = parser.parse_args()

os.makedirs(args.output_dir, exist_ok=True)

acc = 0
acc_best = 0

if args.dataset == 'MNIST':

    data_train = MNIST(args.data,
                       transform=transforms.Compose([
                           transforms.Resize((32, 32)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                           ]))
    data_test = MNIST(args.data,
                      train=False,
                      transform=transforms.Compose([
                          transforms.Resize((32, 32)),
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                          ]))

    data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
    data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)

    net = LeNet5().cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

if args.dataset == 'cifar10':

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    data_train = CIFAR10(args.data,
                         transform=transform_train)
    data_test = CIFAR10(args.data,
                        train=False,
                        transform=transform_test)

    data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=8)
    data_test_loader = DataLoader(data_test, batch_size=100, num_workers=0)

    net = resnet.ResNet34().cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

if args.dataset == 'cifar100':

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    data_train = CIFAR100(args.data,
                          transform=transform_train)
    data_test = CIFAR100(args.data,
                         train=False,
                         transform=transform_test)

    data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=0)
    data_test_loader = DataLoader(data_test, batch_size=128, num_workers=0)
    net = resnet.ResNet34(num_classes=100).cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def adjust_learning_rate(optimizer, epoch):
    """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
    if epoch < 80:
        lr = 0.1
    elif epoch < 120:
        lr = 0.01
    else:
        lr = 0.001
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def train(epoch):
    if args.dataset != 'MNIST':
        adjust_learning_rate(optimizer, epoch)
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        images, labels = Variable(images).cuda(), Variable(labels).cuda()

        optimizer.zero_grad()

        output = net(images)

        loss = criterion(output, labels)

        loss_list.append(loss.item())
        batch_list.append(i+1)

        # Log the loss once per epoch, at the second batch.
        if i == 1:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.item()))

        loss.backward()
        optimizer.step()


def test():
    global acc, acc_best
    net.eval()
    total_correct = 0
    avg_loss = 0.0
    with torch.no_grad():
        for i, (images, labels) in enumerate(data_test_loader):
            images, labels = Variable(images).cuda(), Variable(labels).cuda()
            output = net(images)
            # criterion returns the batch mean, so avg_loss accumulates one value per batch.
            avg_loss += criterion(output, labels).sum()
            pred = output.data.max(1)[1]
            total_correct += pred.eq(labels.data.view_as(pred)).sum()

    avg_loss /= len(data_test)
    acc = float(total_correct) / len(data_test)
    if acc_best < acc:
        acc_best = acc
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.item(), acc))


def train_and_test(epoch):
    train(epoch)
    test()


def main():
    if args.dataset == 'MNIST':
        epoch = 10
    else:
        epoch = 200
    for e in range(1, epoch):
        train_and_test(e)
    # Save the whole module (not just a state_dict); DAFL-train.py loads it with torch.load.
    torch.save(net, os.path.join(args.output_dir, 'teacher'))


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/DAFL-train.py:
--------------------------------------------------------------------------------
#Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.

#This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.

#This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.

import argparse
import os

import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.datasets.mnist import MNIST
from lenet import LeNet5Half
from torchvision.datasets import CIFAR10
from torchvision.datasets import CIFAR100
import resnet

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST','cifar10','cifar100'])
parser.add_argument('--data', type=str, default='/cache/data/')
parser.add_argument('--teacher_dir', type=str, default='/cache/models/')
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=512, help='size of the batches')
parser.add_argument('--lr_G', type=float, default=0.2, help='learning rate of the generator')
parser.add_argument('--lr_S', type=float, default=2e-3, help='learning rate of the student')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--oh', type=float, default=1, help='weight of the one hot loss')
parser.add_argument('--ie', type=float, default=5, help='weight of the information entropy loss')
parser.add_argument('--a', type=float, default=0.1, help='weight of the activation loss')
parser.add_argument('--output_dir', type=str, default='/cache/models/')

opt = parser.parse_args()

os.makedirs(opt.output_dir, exist_ok=True)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True

accr = 0
accr_best = 0

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_size = opt.img_size // 4
        self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))

        self.conv_blocks0 = nn.Sequential(
            nn.BatchNorm2d(128),
        )
        # Note: the second positional argument of BatchNorm2d is eps, so eps=0.8 below.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(opt.channels, affine=False)
        )

    def forward(self, z):
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks0(out)
        # Two 2x upsampling steps take init_size (img_size // 4) back to img_size.
        img = nn.functional.interpolate(img,scale_factor=2)
        img = self.conv_blocks1(img)
        img = nn.functional.interpolate(img,scale_factor=2)
        img = self.conv_blocks2(img)
        return img

generator = Generator().cuda()

# The teacher was saved as a whole module by teacher-train.py.
# Recent PyTorch releases default torch.load to weights_only=True, which rejects
# fully pickled modules; pass weights_only=False there.
teacher = torch.load(os.path.join(opt.teacher_dir, 'teacher')).cuda()
teacher.eval()
criterion = torch.nn.CrossEntropyLoss().cuda()

teacher = nn.DataParallel(teacher)
generator = nn.DataParallel(generator)

def kdloss(y, teacher_scores):
    """KL divergence from the teacher's softmax to the student's, averaged over the batch."""
    p = F.log_softmax(y, dim=1)
    q = F.softmax(teacher_scores, dim=1)
    # reduction='sum' replaces the deprecated size_average=False.
    l_kl = F.kl_div(p, q, reduction='sum') / y.shape[0]
    return l_kl

if opt.dataset == 'MNIST':
    # Configure data loader
    net = LeNet5Half().cuda()
    net = nn.DataParallel(net)
    data_test = MNIST(opt.data,
                      train=False,
                      transform=transforms.Compose([
                          transforms.Resize((32, 32)),
                          transforms.ToTensor(),
                          transforms.Normalize((0.1307,), (0.3081,))
                          ]))
    data_test_loader = DataLoader(data_test, batch_size=64, num_workers=1, shuffle=False)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
    optimizer_S = torch.optim.Adam(net.parameters(), lr=opt.lr_S)

if opt.dataset != 'MNIST':
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if opt.dataset == 'cifar10':
        net = resnet.ResNet18().cuda()
        net = nn.DataParallel(net)
        data_test = CIFAR10(opt.data,
                            train=False,
                            transform=transform_test)
    if opt.dataset == 'cifar100':
        net = resnet.ResNet18(num_classes=100).cuda()
        net = nn.DataParallel(net)
        data_test = CIFAR100(opt.data,
                             train=False,
                             transform=transform_test)
    data_test_loader = DataLoader(data_test, batch_size=opt.batch_size, num_workers=0)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)

    optimizer_S = torch.optim.SGD(net.parameters(), lr=opt.lr_S, momentum=0.9, weight_decay=5e-4)


def adjust_learning_rate(optimizer, epoch, learning_rate):
    """Divide the base learning rate by 10 at epochs 800 and 1600."""
    if epoch < 800:
        lr = learning_rate
    elif epoch < 1600:
        lr = 0.1*learning_rate
    else:
        lr = 0.01*learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):

    total_correct = 0
    avg_loss = 0.0
    if opt.dataset != 'MNIST':
        adjust_learning_rate(optimizer_S, epoch, opt.lr_S)

    # 120 generator/student updates per epoch, each on a fresh batch of noise.
    for i in range(120):
        net.train()
        z = torch.randn(opt.batch_size, opt.latent_dim).cuda()
        optimizer_G.zero_grad()
        optimizer_S.zero_grad()
        gen_imgs = generator(z)
        outputs_T, features_T = teacher(gen_imgs, out_feature=True)
        # Pseudo-labels: the teacher's own predictions on the generated images.
        pred = outputs_T.data.max(1)[1]
        loss_activation = -features_T.abs().mean()
        loss_one_hot = criterion(outputs_T,pred)
        # Batch-averaged class distribution; minimizing sum(p * log10(p)) favors balanced classes.
        softmax_o_T = torch.nn.functional.softmax(outputs_T, dim = 1).mean(dim = 0)
        loss_information_entropy = (softmax_o_T * torch.log10(softmax_o_T)).sum()
        loss = loss_one_hot * opt.oh + loss_information_entropy * opt.ie + loss_activation * opt.a
        # Both inputs are detached, so the distillation term only updates the student.
        loss_kd = kdloss(net(gen_imgs.detach()), outputs_T.detach())
        loss += loss_kd
        loss.backward()
        optimizer_G.step()
        optimizer_S.step()
        if i == 1:
            print ("[Epoch %d/%d] [loss_oh: %f] [loss_ie: %f] [loss_a: %f] [loss_kd: %f]" % (epoch, opt.n_epochs,loss_one_hot.item(), loss_information_entropy.item(), loss_activation.item(), loss_kd.item()))

    with torch.no_grad():
        for i, (images, labels) in enumerate(data_test_loader):
            images = images.cuda()
            labels = labels.cuda()
            net.eval()
            output = net(images)
            avg_loss += criterion(output, labels).sum()
            pred = output.data.max(1)[1]
            total_correct += pred.eq(labels.data.view_as(pred)).sum()

    avg_loss /= len(data_test)
    print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.item(), float(total_correct) / len(data_test)))
    accr = round(float(total_correct) / len(data_test), 4)
    # Keep only the student checkpoint with the best test accuracy so far.
    if accr > accr_best:
        torch.save(net, os.path.join(opt.output_dir, 'student'))
        accr_best = accr
--------------------------------------------------------------------------------
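The distillation term computed by `kdloss` in `DAFL-train.py` can be checked by hand. The following is a plain-Python sketch of the same quantity, the KL divergence from the teacher's softmax to the student's, summed over classes and averaged over the batch. The helper names and toy logits are hypothetical and chosen only for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_sum_exp for v in logits]


def kd_loss(student_logits, teacher_logits):
    """Batch-averaged KL(teacher || student) over rows of logits."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_p = log_softmax(s_row)  # student log-probabilities
        log_q = log_softmax(t_row)  # teacher log-probabilities
        total += sum(math.exp(lq) * (lq - lp) for lq, lp in zip(log_q, log_p))
    return total / len(student_logits)


teacher_logits = [[2.0, 0.5, -1.0], [0.0, 1.0, 3.0]]   # toy values
uniform_student = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # an untrained-looking student

print("student == teacher:", kd_loss(teacher_logits, teacher_logits))
print("uniform student   :", kd_loss(uniform_student, teacher_logits))
```

The loss is zero when the student reproduces the teacher's distribution and positive otherwise. Adding a constant to every student logit leaves it unchanged, because only the softmax probabilities enter the divergence.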