├── figure
│   ├── Table1.jpg
│   ├── Table2.jpg
│   └── figure.jpg
├── License.txt
├── THIRD PARTY OPEN SOURCE SOFTWARE NOTICE.txt
├── README.md
├── lenet.py
├── resnet.py
├── teacher-train.py
└── DAFL-train.py
/figure/Table1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/Table1.jpg
--------------------------------------------------------------------------------
/figure/Table2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/Table2.jpg
--------------------------------------------------------------------------------
/figure/figure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autogyro/DAFL/HEAD/figure/figure.jpg
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019. Huawei Technologies Co., Ltd.
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 |
8 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9 |
10 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11 |
12 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
13 |
14 |
--------------------------------------------------------------------------------
/THIRD PARTY OPEN SOURCE SOFTWARE NOTICE.txt:
--------------------------------------------------------------------------------
1 | Please note we provide an open source software notice for the third party open source software along with this software and/or this software component contributed by Huawei (in the following just “this SOFTWARE”). The open source software licenses are granted by the respective right holders.
2 |
3 | Warranty Disclaimer
4 | THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES FOR MORE DETAILS.
5 |
6 | Copyright Notice and License Texts
7 | Software: pytorch-cifar
8 | Copyright notice:
9 | Copyright (c) 2017
10 |
11 | License: MIT License
12 |
13 | Copyright (c) 2017 liukuang
14 |
15 | Permission is hereby granted, free of charge, to any person obtaining a copy
16 | of this software and associated documentation files (the "Software"), to deal
17 | in the Software without restriction, including without limitation the rights
18 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
19 | copies of the Software, and to permit persons to whom the Software is
20 | furnished to do so, subject to the following conditions:
21 |
22 | The above copyright notice and this permission notice shall be included in all
23 | copies or substantial portions of the Software.
24 |
25 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
31 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAFL: Data-Free Learning of Student Networks
2 | This code is the PyTorch implementation of the ICCV 2019 paper [DAFL: Data-Free Learning of Student Networks](https://arxiv.org/pdf/1904.01186.pdf).
3 |
4 | We propose a novel framework for training efficient deep neural networks by exploiting generative adversarial networks (GANs). Specifically, the pre-trained teacher network is treated as a fixed discriminator, and the generator is used to produce training samples that elicit the maximum response from that discriminator. An efficient student network with a smaller model size and lower computational complexity is then trained using both the generated data and the teacher network.
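
One ingredient of the generator objective, the class-balance term, can be illustrated without PyTorch. The sketch below mirrors the `loss_information_entropy` computation in `DAFL-train.py` (average the teacher's softmax over the batch, then sum `p * log10(p)`). The helper names `softmax` and `information_entropy_loss` and the toy logits are illustrative assumptions, not part of the repository.

```python
import math


def softmax(logits):
    """Numerically stable softmax over one list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def information_entropy_loss(batch_logits):
    """Sum of p * log10(p), where p is the batch-averaged softmax.

    This is the negative entropy of the average prediction, so it is
    lowest when the batch covers all classes evenly.
    """
    probs = [softmax(row) for row in batch_logits]
    n_classes = len(probs[0])
    mean_p = [sum(row[c] for row in probs) / len(probs) for c in range(n_classes)]
    return sum(p * math.log10(p) for p in mean_p if p > 0.0)


# Toy teacher outputs: one batch spread over 3 classes, one collapsed onto class 0.
balanced = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]
collapsed = [[5.0, 0.0, 0.0], [5.0, 0.0, 0.0], [5.0, 0.0, 0.0]]

loss_balanced = information_entropy_loss(balanced)
loss_collapsed = information_entropy_loss(collapsed)
print(loss_balanced < loss_collapsed)
```

Minimizing this term therefore pushes the generator toward batches whose predicted classes are evenly distributed, which is why the script adds it to the loss with the weight `--ie`.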
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Requirements
12 | - python 3
13 | - pytorch >= 1.0.0
14 | - torchvision
15 |
16 | ## Run the demo
17 | ```shell
18 | python teacher-train.py
19 | ```
20 | First, you should train a teacher network.
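
The MNIST teacher is the `LeNet5` model from `lenet.py`, and the training script resizes the 28x28 digits to 32x32 before feeding them in. A quick size calculation suggests why: the sketch below traces the spatial side length through the layers of `LeNet5`. The helpers `conv_out` and `lenet5_final_side` are hypothetical names introduced only for this illustration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a square convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_final_side(input_side):
    """Trace the spatial side length through the LeNet5 layers in lenet.py."""
    side = conv_out(input_side, kernel=5)      # conv1, 5x5
    side = conv_out(side, kernel=2, stride=2)  # maxpool1, 2x2 with stride 2
    side = conv_out(side, kernel=5)            # conv2, 5x5
    side = conv_out(side, kernel=2, stride=2)  # maxpool2, 2x2 with stride 2
    side = conv_out(side, kernel=5)            # conv3, 5x5
    return side


final_32 = lenet5_final_side(32)
final_28 = lenet5_final_side(28)
print(final_32)
```

With a 32x32 input the last convolution produces a 1x1 map with 120 channels, which matches `output.view(-1, 120)` in the model. A 28x28 input would leave a 4x4 map before the last 5x5 convolution, which is too small for that kernel.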
21 | ```shell
22 | python DAFL-train.py
23 | ```
24 | Then, you can use DAFL to train a student network on the MNIST dataset without any training data.
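
The student learns only from the teacher's outputs on generated images. As a rough, framework-free sketch of the distillation term computed by `kdloss` in `DAFL-train.py` (KL divergence between the teacher's softmax and the student's log-softmax, summed and divided by the batch size), consider the following. The function names `log_softmax` and `kd_loss` and the example logits are assumptions made for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one list of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def kd_loss(student_logits, teacher_logits):
    """KL(teacher || student), summed over classes and samples, per batch item."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_p = log_softmax(s_row)  # student log-probabilities
        log_q = log_softmax(t_row)  # teacher log-probabilities
        total += sum(math.exp(lq) * (lq - lp) for lp, lq in zip(log_p, log_q))
    return total / len(student_logits)


teacher = [[4.0, 1.0, 0.0], [0.0, 3.0, 1.0]]
matching_student = [[4.0, 1.0, 0.0], [0.0, 3.0, 1.0]]
untrained_student = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]

loss_match = kd_loss(matching_student, teacher)
loss_untrained = kd_loss(untrained_student, teacher)
print(loss_match, loss_untrained)
```

The loss is zero when the student reproduces the teacher's distribution and grows as the two diverge, so no ground-truth labels are needed.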
25 |
26 | To run DAFL on the CIFAR-10 dataset
27 | ```shell
28 | python teacher-train.py --dataset cifar10
29 | python DAFL-train.py --dataset cifar10 --channels 3 --n_epochs 2000 --batch_size 1024 --lr_G 0.02 --lr_S 0.1 --latent_dim 1000
30 | ```
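
The long `--n_epochs 2000` run interacts with the step schedule that `DAFL-train.py` applies to the student optimizer for the CIFAR datasets. A minimal sketch of that schedule, with the hypothetical function name `student_lr`, is:

```python
def student_lr(epoch, base_lr):
    """Step schedule used for the student on CIFAR: x0.1 at epoch 800, x0.01 at 1600."""
    if epoch < 800:
        return base_lr
    if epoch < 1600:
        return 0.1 * base_lr
    return 0.01 * base_lr


# Learning rates seen with --lr_S 0.1 at a few points of a 2000-epoch run.
schedule = {epoch: student_lr(epoch, 0.1) for epoch in (0, 799, 800, 1599, 1600, 1999)}
print(schedule)
```

With the default `--n_epochs 200` the thresholds are never reached, so the decay only takes effect for long runs such as the one above. The generator's learning rate (`--lr_G`) is not changed by this schedule.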
31 |
32 | To run DAFL on the CIFAR-100 dataset
33 | ```shell
34 | python teacher-train.py --dataset cifar100
35 | python DAFL-train.py --dataset cifar100 --channels 3 --n_epochs 2000 --batch_size 1024 --lr_G 0.02 --lr_S 0.1 --latent_dim 1000 --oh 0.5 --ie 20
36 | ```
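
To see which settings the CIFAR-100 command changes relative to the script defaults, the following sketch rebuilds a trimmed copy of the argument parser from `DAFL-train.py` and parses the same flags. Only a subset of the arguments is reproduced, and `build_parser` is a hypothetical helper name.

```python
import argparse


def build_parser():
    """Trimmed copy of the DAFL-train.py parser (subset of arguments, same defaults)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='MNIST',
                        choices=['MNIST', 'cifar10', 'cifar100'])
    parser.add_argument('--channels', type=int, default=1)
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=512)
    parser.add_argument('--lr_G', type=float, default=0.2)
    parser.add_argument('--lr_S', type=float, default=2e-3)
    parser.add_argument('--latent_dim', type=int, default=100)
    parser.add_argument('--oh', type=float, default=1)    # one-hot loss weight
    parser.add_argument('--ie', type=float, default=5)    # information entropy loss weight
    parser.add_argument('--a', type=float, default=0.1)   # activation loss weight
    return parser


parser = build_parser()
cifar100_flags = ("--dataset cifar100 --channels 3 --n_epochs 2000 --batch_size 1024 "
                  "--lr_G 0.02 --lr_S 0.1 --latent_dim 1000 --oh 0.5 --ie 20").split()
opt = parser.parse_args(cifar100_flags)

# Keep only the options whose value differs from the parser default.
overridden = {k: v for k, v in vars(opt).items() if v != parser.get_default(k)}
print(overridden)
```

Compared with the CIFAR-10 command, the only additional overrides are the loss weights `--oh` and `--ie`; the activation loss weight `--a` stays at its default.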
37 |
38 | ## Results
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 | ## Citation
47 | @inproceedings{DAFL,
48 | title={DAFL: Data-Free Learning of Student Networks},
49 | author={Chen, Hanting and Wang, Yunhe and Xu, Chang and Yang, Zhaohui and Liu, Chuanjian and Shi, Boxin and Xu, Chunjing and Xu, Chao and Tian, Qi},
50 | booktitle={ICCV},
51 | year={2019}
52 | }
53 |
54 | ## Contributing
55 | We appreciate all contributions. If you plan to contribute bug fixes, please do so without any further discussion.
56 |
57 | If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. A PR sent without prior discussion may be rejected, because we might be taking the core in a different direction than you are aware of.
58 |
--------------------------------------------------------------------------------
/lenet.py:
--------------------------------------------------------------------------------
1 | #Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
2 |
3 | #This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.
4 |
5 | #This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.
6 |
7 | import torch.nn as nn
8 |
9 |
10 | class LeNet5(nn.Module):
11 |
12 |     def __init__(self):
13 |         super(LeNet5, self).__init__()
14 |
15 |         self.conv1 = nn.Conv2d(1, 6, kernel_size=(5, 5))
16 |         self.relu1 = nn.ReLU()
17 |         self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
18 |         self.conv2 = nn.Conv2d(6, 16, kernel_size=(5, 5))
19 |         self.relu2 = nn.ReLU()
20 |         self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
21 |         self.conv3 = nn.Conv2d(16, 120, kernel_size=(5, 5))
22 |         self.relu3 = nn.ReLU()
23 |         self.fc1 = nn.Linear(120, 84)
24 |         self.relu4 = nn.ReLU()
25 |         self.fc2 = nn.Linear(84, 10)
26 |
27 |     def forward(self, img, out_feature=False):
28 |         output = self.conv1(img)
29 |         output = self.relu1(output)
30 |         output = self.maxpool1(output)
31 |         output = self.conv2(output)
32 |         output = self.relu2(output)
33 |         output = self.maxpool2(output)
34 |         output = self.conv3(output)
35 |         output = self.relu3(output)
36 |         feature = output.view(-1, 120)
37 |         output = self.fc1(feature)
38 |         output = self.relu4(output)
39 |         output = self.fc2(output)
40 |         if not out_feature:
41 |             return output
42 |         else:
43 |             return output, feature
44 |
45 |
46 | class LeNet5Half(nn.Module):
47 |
48 |     def __init__(self):
49 |         super(LeNet5Half, self).__init__()
50 |
51 |         self.conv1 = nn.Conv2d(1, 3, kernel_size=(5, 5))
52 |         self.relu1 = nn.ReLU()
53 |         self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
54 |         self.conv2 = nn.Conv2d(3, 8, kernel_size=(5, 5))
55 |         self.relu2 = nn.ReLU()
56 |         self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
57 |         self.conv3 = nn.Conv2d(8, 60, kernel_size=(5, 5))
58 |         self.relu3 = nn.ReLU()
59 |         self.fc1 = nn.Linear(60, 42)
60 |         self.relu4 = nn.ReLU()
61 |         self.fc2 = nn.Linear(42, 10)
62 |
63 |     def forward(self, img, out_feature=False):
64 |         output = self.conv1(img)
65 |         output = self.relu1(output)
66 |         output = self.maxpool1(output)
67 |         output = self.conv2(output)
68 |         output = self.relu2(output)
69 |         output = self.maxpool2(output)
70 |         output = self.conv3(output)
71 |         output = self.relu3(output)
72 |         feature = output.view(-1, 60)
73 |         output = self.fc1(feature)
74 |         output = self.relu4(output)
75 |         output = self.fc2(output)
76 |         if not out_feature:
77 |             return output
78 |         else:
79 |             return output, feature
--------------------------------------------------------------------------------
/resnet.py:
--------------------------------------------------------------------------------
1 | # 2019.07.24-Changed output of forward function
2 | # Huawei Technologies Co., Ltd.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class BasicBlock(nn.Module):
10 |     expansion = 1
11 |
12 |     def __init__(self, in_planes, planes, stride=1):
13 |         super(BasicBlock, self).__init__()
14 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
15 |         self.bn1 = nn.BatchNorm2d(planes)
16 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
17 |         self.bn2 = nn.BatchNorm2d(planes)
18 |
19 |         self.shortcut = nn.Sequential()
20 |         if stride != 1 or in_planes != self.expansion*planes:
21 |             self.shortcut = nn.Sequential(
22 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
23 |                 nn.BatchNorm2d(self.expansion*planes)
24 |             )
25 |
26 |     def forward(self, x):
27 |         out = F.relu(self.bn1(self.conv1(x)))
28 |         out = self.bn2(self.conv2(out))
29 |         out += self.shortcut(x)
30 |         out = F.relu(out)
31 |         return out
32 |
33 |
34 | class Bottleneck(nn.Module):
35 |     expansion = 4
36 |
37 |     def __init__(self, in_planes, planes, stride=1):
38 |         super(Bottleneck, self).__init__()
39 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
40 |         self.bn1 = nn.BatchNorm2d(planes)
41 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
42 |         self.bn2 = nn.BatchNorm2d(planes)
43 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
44 |         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
45 |
46 |         self.shortcut = nn.Sequential()
47 |         if stride != 1 or in_planes != self.expansion*planes:
48 |             self.shortcut = nn.Sequential(
49 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
50 |                 nn.BatchNorm2d(self.expansion*planes)
51 |             )
52 |
53 |     def forward(self, x):
54 |         out = F.relu(self.bn1(self.conv1(x)))
55 |         out = F.relu(self.bn2(self.conv2(out)))
56 |         out = self.bn3(self.conv3(out))
57 |         out += self.shortcut(x)
58 |         out = F.relu(out)
59 |         return out
60 |
61 |
62 | class ResNet(nn.Module):
63 |     def __init__(self, block, num_blocks, num_classes=10):
64 |         super(ResNet, self).__init__()
65 |         self.in_planes = 64
66 |
67 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
68 |         self.bn1 = nn.BatchNorm2d(64)
69 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
70 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
71 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
72 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
73 |         self.linear = nn.Linear(512*block.expansion, num_classes)
74 |
75 |     def _make_layer(self, block, planes, num_blocks, stride):
76 |         strides = [stride] + [1]*(num_blocks-1)
77 |         layers = []
78 |         for stride in strides:
79 |             layers.append(block(self.in_planes, planes, stride))
80 |             self.in_planes = planes * block.expansion
81 |         return nn.Sequential(*layers)
82 |
83 |     def forward(self, x, out_feature=False):
84 |         out = F.relu(self.bn1(self.conv1(x)))
85 |         out = self.layer1(out)
86 |         out = self.layer2(out)
87 |         out = self.layer3(out)
88 |         out = self.layer4(out)
89 |         out = F.avg_pool2d(out, 4)
90 |         feature = out.view(out.size(0), -1)
91 |         out = self.linear(feature)
92 |         if not out_feature:
93 |             return out
94 |         else:
95 |             return out, feature
96 |
97 |
98 | def ResNet18(num_classes=10):
99 |     return ResNet(BasicBlock, [2,2,2,2], num_classes)
100 |
101 | def ResNet34(num_classes=10):
102 |     return ResNet(BasicBlock, [3,4,6,3], num_classes)
103 |
104 | def ResNet50(num_classes=10):
105 |     return ResNet(Bottleneck, [3,4,6,3], num_classes)
106 |
107 | def ResNet101(num_classes=10):
108 |     return ResNet(Bottleneck, [3,4,23,3], num_classes)
109 |
110 | def ResNet152(num_classes=10):
111 |     return ResNet(Bottleneck, [3,8,36,3], num_classes)
112 |
113 |
--------------------------------------------------------------------------------
/teacher-train.py:
--------------------------------------------------------------------------------
1 | #Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
2 |
3 | #This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.
4 |
5 | #This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.
6 |
7 | import os
8 | from lenet import LeNet5
9 | import resnet
10 | import torch
11 | from torch.autograd import Variable
12 | from torchvision.datasets.mnist import MNIST
13 | from torchvision.datasets import CIFAR10
14 | from torchvision.datasets import CIFAR100
15 | import torchvision.transforms as transforms
16 | from torch.utils.data import DataLoader
17 | import argparse
18 |
19 | parser = argparse.ArgumentParser(description='train-teacher-network')
20 |
21 | # Basic model parameters.
22 | parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST','cifar10','cifar100'])
23 | parser.add_argument('--data', type=str, default='/cache/data/')
24 | parser.add_argument('--output_dir', type=str, default='/cache/models/')
25 | args = parser.parse_args()
26 |
27 | os.makedirs(args.output_dir, exist_ok=True)
28 |
29 | acc = 0
30 | acc_best = 0
31 |
32 | if args.dataset == 'MNIST':
33 |
34 |     data_train = MNIST(args.data,
35 |                        transform=transforms.Compose([
36 |                            transforms.Resize((32, 32)),
37 |                            transforms.ToTensor(),
38 |                            transforms.Normalize((0.1307,), (0.3081,))
39 |                        ]))
40 |     data_test = MNIST(args.data,
41 |                       train=False,
42 |                       transform=transforms.Compose([
43 |                           transforms.Resize((32, 32)),
44 |                           transforms.ToTensor(),
45 |                           transforms.Normalize((0.1307,), (0.3081,))
46 |                       ]))
47 |
48 |     data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
49 |     data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
50 |
51 |     net = LeNet5().cuda()
52 |     criterion = torch.nn.CrossEntropyLoss().cuda()
53 |     optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
54 |
55 | if args.dataset == 'cifar10':
56 |
57 |     transform_train = transforms.Compose([
58 |         transforms.RandomCrop(32, padding=4),
59 |         transforms.RandomHorizontalFlip(),
60 |         transforms.ToTensor(),
61 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
62 |     ])
63 |
64 |     transform_test = transforms.Compose([
65 |         transforms.ToTensor(),
66 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
67 |     ])
68 |
69 |     data_train = CIFAR10(args.data,
70 |                          transform=transform_train)
71 |     data_test = CIFAR10(args.data,
72 |                         train=False,
73 |                         transform=transform_test)
74 |
75 |     data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=8)
76 |     data_test_loader = DataLoader(data_test, batch_size=100, num_workers=0)
77 |
78 |     net = resnet.ResNet34().cuda()
79 |     criterion = torch.nn.CrossEntropyLoss().cuda()
80 |     optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
81 |
82 | if args.dataset == 'cifar100':
83 |
84 |     transform_train = transforms.Compose([
85 |         transforms.RandomCrop(32, padding=4),
86 |         transforms.RandomHorizontalFlip(),
87 |         transforms.ToTensor(),
88 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
89 |     ])
90 |
91 |     transform_test = transforms.Compose([
92 |         transforms.ToTensor(),
93 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
94 |     ])
95 |
96 |     data_train = CIFAR100(args.data,
97 |                           transform=transform_train)
98 |     data_test = CIFAR100(args.data,
99 |                          train=False,
100 |                          transform=transform_test)
101 |
102 |     data_train_loader = DataLoader(data_train, batch_size=128, shuffle=True, num_workers=0)
103 |     data_test_loader = DataLoader(data_test, batch_size=128, num_workers=0)
104 |     net = resnet.ResNet34(num_classes=100).cuda()
105 |     criterion = torch.nn.CrossEntropyLoss().cuda()
106 |     optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
107 |
108 |
109 | def adjust_learning_rate(optimizer, epoch):
110 |     """For resnet, the lr starts from 0.1, and is divided by 10 at 80 and 120 epochs"""
111 |     if epoch < 80:
112 |         lr = 0.1
113 |     elif epoch < 120:
114 |         lr = 0.01
115 |     else:
116 |         lr = 0.001
117 |     for param_group in optimizer.param_groups:
118 |         param_group['lr'] = lr
119 |
120 | def train(epoch):
121 |     if args.dataset != 'MNIST':
122 |         adjust_learning_rate(optimizer, epoch)
123 |     global cur_batch_win
124 |     net.train()
125 |     loss_list, batch_list = [], []
126 |     for i, (images, labels) in enumerate(data_train_loader):
127 |         images, labels = Variable(images).cuda(), Variable(labels).cuda()
128 |
129 |         optimizer.zero_grad()
130 |
131 |         output = net(images)
132 |
133 |         loss = criterion(output, labels)
134 |
135 |         loss_list.append(loss.data.item())
136 |         batch_list.append(i+1)
137 |
138 |         if i == 1:
139 |             print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.data.item()))
140 |
141 |         loss.backward()
142 |         optimizer.step()
143 |
144 |
145 | def test():
146 |     global acc, acc_best
147 |     net.eval()
148 |     total_correct = 0
149 |     avg_loss = 0.0
150 |     with torch.no_grad():
151 |         for i, (images, labels) in enumerate(data_test_loader):
152 |             images, labels = Variable(images).cuda(), Variable(labels).cuda()
153 |             output = net(images)
154 |             avg_loss += criterion(output, labels).sum()
155 |             pred = output.data.max(1)[1]
156 |             total_correct += pred.eq(labels.data.view_as(pred)).sum()
157 |
158 |     avg_loss /= len(data_test_loader)  # criterion already averages within each batch
159 |     acc = float(total_correct) / len(data_test)
160 |     if acc_best < acc:
161 |         acc_best = acc
162 |     print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.data.item(), acc))
163 |
164 |
165 | def train_and_test(epoch):
166 |     train(epoch)
167 |     test()
168 |
169 |
170 | def main():
171 |     if args.dataset == 'MNIST':
172 |         epoch = 10
173 |     else:
174 |         epoch = 200
175 |     for e in range(1, epoch):
176 |         train_and_test(e)
177 |     torch.save(net,args.output_dir + 'teacher')
178 |
179 |
180 | if __name__ == '__main__':
181 |     main()
--------------------------------------------------------------------------------
/DAFL-train.py:
--------------------------------------------------------------------------------
1 | #Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
2 |
3 | #This program is free software; you can redistribute it and/or modify it under the terms of the BSD 3-Clause License.
4 |
5 | #This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD 3-Clause License for more details.
6 |
7 | import argparse
8 | import os
9 | import numpy as np
10 | import math
11 | import sys
12 | import pdb
13 |
14 | import torchvision.transforms as transforms
15 |
16 | from torch.utils.data import DataLoader
17 | from torchvision import datasets
18 | from torch.autograd import Variable
19 |
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch
23 | from torchvision.datasets.mnist import MNIST
24 | from lenet import LeNet5Half
25 | from torchvision.datasets import CIFAR10
26 | from torchvision.datasets import CIFAR100
27 | import resnet
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('--dataset', type=str, default='MNIST', choices=['MNIST','cifar10','cifar100'])
31 | parser.add_argument('--data', type=str, default='/cache/data/')
32 | parser.add_argument('--teacher_dir', type=str, default='/cache/models/')
33 | parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
34 | parser.add_argument('--batch_size', type=int, default=512, help='size of the batches')
35 | parser.add_argument('--lr_G', type=float, default=0.2, help='learning rate of the generator')
36 | parser.add_argument('--lr_S', type=float, default=2e-3, help='learning rate of the student')
37 | parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
38 | parser.add_argument('--img_size', type=int, default=32, help='size of each image dimension')
39 | parser.add_argument('--channels', type=int, default=1, help='number of image channels')
40 | parser.add_argument('--oh', type=float, default=1, help='weight of the one-hot loss')
41 | parser.add_argument('--ie', type=float, default=5, help='weight of the information entropy loss')
42 | parser.add_argument('--a', type=float, default=0.1, help='weight of the activation loss')
43 | parser.add_argument('--output_dir', type=str, default='/cache/models/')
44 |
45 | opt = parser.parse_args()
46 |
47 | img_shape = (opt.channels, opt.img_size, opt.img_size)
48 |
49 | cuda = True
50 |
51 | accr = 0
52 | accr_best = 0
53 |
54 | class Generator(nn.Module):
55 |     def __init__(self):
56 |         super(Generator, self).__init__()
57 |
58 |         self.init_size = opt.img_size // 4
59 |         self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128*self.init_size**2))
60 |
61 |         self.conv_blocks0 = nn.Sequential(
62 |             nn.BatchNorm2d(128),
63 |         )
64 |         self.conv_blocks1 = nn.Sequential(
65 |             nn.Conv2d(128, 128, 3, stride=1, padding=1),
66 |             nn.BatchNorm2d(128, 0.8),
67 |             nn.LeakyReLU(0.2, inplace=True),
68 |         )
69 |         self.conv_blocks2 = nn.Sequential(
70 |             nn.Conv2d(128, 64, 3, stride=1, padding=1),
71 |             nn.BatchNorm2d(64, 0.8),
72 |             nn.LeakyReLU(0.2, inplace=True),
73 |             nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
74 |             nn.Tanh(),
75 |             nn.BatchNorm2d(opt.channels, affine=False)
76 |         )
77 |
78 |     def forward(self, z):
79 |         out = self.l1(z)
80 |         out = out.view(out.shape[0], 128, self.init_size, self.init_size)
81 |         img = self.conv_blocks0(out)
82 |         img = nn.functional.interpolate(img,scale_factor=2)
83 |         img = self.conv_blocks1(img)
84 |         img = nn.functional.interpolate(img,scale_factor=2)
85 |         img = self.conv_blocks2(img)
86 |         return img
87 |
88 | generator = Generator().cuda()
89 |
90 | teacher = torch.load(opt.teacher_dir + 'teacher').cuda()
91 | teacher.eval()
92 | criterion = torch.nn.CrossEntropyLoss().cuda()
93 |
94 | teacher = nn.DataParallel(teacher)
95 | generator = nn.DataParallel(generator)
96 |
97 | def kdloss(y, teacher_scores):
98 |     p = F.log_softmax(y, dim=1)
99 |     q = F.softmax(teacher_scores, dim=1)
100 |     l_kl = F.kl_div(p, q, reduction='sum') / y.shape[0]  # 'size_average' is deprecated
101 |     return l_kl
102 |
103 | if opt.dataset == 'MNIST':
104 |     # Configure data loader
105 |     net = LeNet5Half().cuda()
106 |     net = nn.DataParallel(net)
107 |     data_test = MNIST(opt.data,
108 |                       train=False,
109 |                       transform=transforms.Compose([
110 |                           transforms.Resize((32, 32)),
111 |                           transforms.ToTensor(),
112 |                           transforms.Normalize((0.1307,), (0.3081,))
113 |                       ]))
114 |     data_test_loader = DataLoader(data_test, batch_size=64, num_workers=1, shuffle=False)
115 |
116 |     # Optimizers
117 |     optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
118 |     optimizer_S = torch.optim.Adam(net.parameters(), lr=opt.lr_S)
119 |
120 | if opt.dataset != 'MNIST':
121 |     transform_test = transforms.Compose([
122 |         transforms.ToTensor(),
123 |         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
124 |     ])
125 |     if opt.dataset == 'cifar10':
126 |         net = resnet.ResNet18().cuda()
127 |         net = nn.DataParallel(net)
128 |         data_test = CIFAR10(opt.data,
129 |                             train=False,
130 |                             transform=transform_test)
131 |     if opt.dataset == 'cifar100':
132 |         net = resnet.ResNet18(num_classes=100).cuda()
133 |         net = nn.DataParallel(net)
134 |         data_test = CIFAR100(opt.data,
135 |                              train=False,
136 |                              transform=transform_test)
137 |     data_test_loader = DataLoader(data_test, batch_size=opt.batch_size, num_workers=0)
138 |
139 |     # Optimizers
140 |     optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr_G)
141 |
142 |     optimizer_S = torch.optim.SGD(net.parameters(), lr=opt.lr_S, momentum=0.9, weight_decay=5e-4)
143 |
144 |
145 | def adjust_learning_rate(optimizer, epoch, learning_rate):
146 |     if epoch < 800:
147 |         lr = learning_rate
148 |     elif epoch < 1600:
149 |         lr = 0.1*learning_rate
150 |     else:
151 |         lr = 0.01*learning_rate
152 |     for param_group in optimizer.param_groups:
153 |         param_group['lr'] = lr
154 |
155 |
156 | # ----------
157 | # Training
158 | # ----------
159 |
160 | batches_done = 0
161 | for epoch in range(opt.n_epochs):
162 |
163 |     total_correct = 0
164 |     avg_loss = 0.0
165 |     if opt.dataset != 'MNIST':
166 |         adjust_learning_rate(optimizer_S, epoch, opt.lr_S)
167 |
168 |     for i in range(120):
169 |         net.train()
170 |         z = Variable(torch.randn(opt.batch_size, opt.latent_dim)).cuda()
171 |         optimizer_G.zero_grad()
172 |         optimizer_S.zero_grad()
173 |         gen_imgs = generator(z)
174 |         outputs_T, features_T = teacher(gen_imgs, out_feature=True)
175 |         pred = outputs_T.data.max(1)[1]
176 |         loss_activation = -features_T.abs().mean()
177 |         loss_one_hot = criterion(outputs_T,pred)
178 |         softmax_o_T = torch.nn.functional.softmax(outputs_T, dim = 1).mean(dim = 0)
179 |         loss_information_entropy = (softmax_o_T * torch.log10(softmax_o_T)).sum()
180 |         loss = loss_one_hot * opt.oh + loss_information_entropy * opt.ie + loss_activation * opt.a
181 |         loss_kd = kdloss(net(gen_imgs.detach()), outputs_T.detach())
182 |         loss += loss_kd
183 |         loss.backward()
184 |         optimizer_G.step()
185 |         optimizer_S.step()
186 |         if i == 1:
187 |             print("[Epoch %d/%d] [loss_oh: %f] [loss_ie: %f] [loss_a: %f] [loss_kd: %f]" % (epoch, opt.n_epochs,loss_one_hot.item(), loss_information_entropy.item(), loss_activation.item(), loss_kd.item()))
188 |
189 |     with torch.no_grad():
190 |         for i, (images, labels) in enumerate(data_test_loader):
191 |             images = images.cuda()
192 |             labels = labels.cuda()
193 |             net.eval()
194 |             output = net(images)
195 |             avg_loss += criterion(output, labels).sum()
196 |             pred = output.data.max(1)[1]
197 |             total_correct += pred.eq(labels.data.view_as(pred)).sum()
198 |
199 |     avg_loss /= len(data_test_loader)  # criterion already averages within each batch
200 |     print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.data.item(), float(total_correct) / len(data_test)))
201 |     accr = round(float(total_correct) / len(data_test), 4)
202 |     if accr > accr_best:
203 |         torch.save(net,opt.output_dir + 'student')
204 |         accr_best = accr
205 |
--------------------------------------------------------------------------------