├── README.md
├── model_factory.py
├── data_loader.py
├── plain_cnn_cifar.py
├── resnet_cifar.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Densely Guided Knowledge Distillation using Multiple Teacher Assistants
2 |
3 | PyTorch code for densely guided knowledge distillation (DGKD): a student network is trained
4 | with a classification loss plus knowledge-distillation losses from a teacher and multiple
5 | teacher assistants (TAs), and a random subset of the guiding losses is dropped at each step
6 | (stochastic DGKD).
7 |
8 | Files:
9 | - `train.py`: training entry point for the teacher, the TAs, and the student
10 | - `model_factory.py`: builds ResNet and plain-CNN models by name (e.g. `resnet110`, `plane2`)
11 | - `data_loader.py`: CIFAR-10/CIFAR-100 data loaders
12 | - `resnet_cifar.py`, `plain_cnn_cifar.py`: model definitions
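13 |
14 | Example: distilling a `plane2` student from a `plane10` teacher through three TAs
15 | (checkpoint paths are placeholders):
16 |
17 | ```bash
18 | python train.py --dataset cifar100 --teacher plane10 --teacher-checkpoint <path> \
19 |     --ta1 plane8 --ta1-checkpoint <path> --ta2 plane6 --ta2-checkpoint <path> \
20 |     --ta3 plane4 --ta3-checkpoint <path> --student plane2 --TA-count 3 --drop-num 1
21 | ```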
--------------------------------------------------------------------------------
/model_factory.py:
--------------------------------------------------------------------------------
1 | from resnet_cifar import *
2 | from plain_cnn_cifar import *
3 |
4 |
5 | def is_resnet(name):
6 |     """
7 |     Checks whether `name` refers to a ResNet; by convention, all ResNet names start with 'resnet'.
8 |     :param name: model name, e.g. 'resnet110' or 'plane4'
9 |     :return: True if the name denotes a ResNet
10 |     """
11 | name = name.lower()
12 | return name.startswith('resnet')
13 |
14 |
15 | def create_cnn_model(name, dataset="cifar100", use_cuda=False):
16 |     """
17 |     Create a model for training, given a model name and dataset.
18 |     :param name: name of the model, e.g. resnet110, resnet32, plane2, plane10, ...
19 |     :param dataset: dataset used to determine the last layer's output size; 'cifar10' or 'cifar100'
20 |     :param use_cuda: if True, move the model to the GPU
21 |     :return: a PyTorch model
22 |     """
22 | num_classes = 100 if dataset == 'cifar100' else 10
23 | model = None
24 | if is_resnet(name):
25 | resnet_size = name[6:]
26 | resnet_model = resnet_book.get(resnet_size)(num_classes=num_classes)
27 | model = resnet_model
28 |
29 | else:
30 | plane_size = name[5:]
31 | model_spec = plane_cifar10_book.get(plane_size) if num_classes == 10 else plane_cifar100_book.get(plane_size)
32 | plane_model = ConvNetMaker(model_spec)
33 | model = plane_model
34 |
35 | # copy to cuda if activated
36 | if use_cuda:
37 | model = model.cuda()
38 |
39 | return model
40 |
41 | if __name__ == "__main__":
42 | dataset = 'cifar100'
43 | print('planes')
44 | for p in [2, 4, 6, 8, 10]:
45 | plane_name = "plane" + str(p)
46 | print(create_cnn_model(plane_name, dataset))
47 |
48 | print('-'*20)
49 | print("resnets")
50 | for r in [8, 14, 20, 26, 32, 44, 56, 110]:
51 | resnet_name = "resnet" + str(r)
52 | print(create_cnn_model(resnet_name, dataset))
53 |
--------------------------------------------------------------------------------
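A minimal usage sketch (not part of the repo): build the kind of teacher/TA/student chain that `train.py` expects and report each model's size. The particular chain below is illustrative, not prescribed by the code.

```python
from model_factory import create_cnn_model

# teacher -> TAs -> student, largest to smallest (an example chain)
for name in ["resnet110", "resnet56", "resnet32", "resnet20", "resnet8"]:
    model = create_cnn_model(name, dataset="cifar100")
    n_params = sum(p.numel() for p in model.parameters())
    print("{}: {:.2f}M parameters".format(name, n_params / 1e6))
```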
/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torchvision.transforms as transforms
4 | import os
5 |
6 | NUM_WORKERS = os.cpu_count()
7 |
8 |
9 | def get_cifar(num_classes=100, dataset_dir='./data', batch_size=128, crop=False):
10 |     """
11 |     :param num_classes: 10 for cifar10, 100 for cifar100
12 |     :param dataset_dir: location of datasets; defaults to a directory named 'data'
13 |     :param batch_size: batch size, defaults to 128
14 |     :param crop: whether to apply random-crop and horizontal-flip augmentation; defaults to False
15 |     :return: train and test DataLoaders
16 |     """
17 |     # channel statistics computed on CIFAR-100 (also applied to CIFAR-10 here)
18 |     normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
18 | simple_transform = transforms.Compose([transforms.ToTensor(), normalize])
19 |
20 |     if crop:
21 | train_transform = transforms.Compose([
22 | transforms.RandomCrop(32, padding=4),
23 | transforms.RandomHorizontalFlip(),
24 | transforms.ToTensor(),
25 | normalize
26 | ])
27 | else:
28 | train_transform = simple_transform
29 |
30 | if num_classes == 100:
31 | trainset = torchvision.datasets.CIFAR100(root=dataset_dir, train=True,
32 | download=True, transform=train_transform)
33 |
34 | testset = torchvision.datasets.CIFAR100(root=dataset_dir, train=False,
35 | download=True, transform=simple_transform)
36 | else:
37 | trainset = torchvision.datasets.CIFAR10(root=dataset_dir, train=True,
38 | download=True, transform=train_transform)
39 |
40 | testset = torchvision.datasets.CIFAR10(root=dataset_dir, train=False,
41 | download=True, transform=simple_transform)
42 |
43 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=NUM_WORKERS,
44 | pin_memory=True, shuffle=True)
45 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=NUM_WORKERS,
46 | pin_memory=True, shuffle=False)
47 | return trainloader, testloader
48 |
49 |
50 | if __name__ == "__main__":
51 | print("CIFAR10")
52 | print(get_cifar(10))
53 | print("---"*20)
54 | print("---"*20)
55 | print("CIFAR100")
56 | print(get_cifar(100))
57 |
--------------------------------------------------------------------------------
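A quick sanity check (not part of the repo): fetch one augmented CIFAR-100 batch and confirm tensor shapes; the dataset is downloaded on first use.

```python
from data_loader import get_cifar

train_loader, test_loader = get_cifar(num_classes=100, batch_size=128, crop=True)
images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 32, 32])
print(labels.shape)  # torch.Size([128])
```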
/plain_cnn_cifar.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
9 |
10 |
11 | class ConvNetMaker(nn.Module):
12 |     """
13 |     Builds a simple (plain) convolutional neural network from a layer specification.
14 |     """
15 |     def __init__(self, layers):
16 |         """
17 |         Makes a CNN from the provided layer specification list;
18 |         the details of this list are available in the paper.
19 |         :param layers: a list of strings representing layers, e.g. ['Conv32', 'MaxPool', 'Conv32', 'MaxPool', 'FC10']
20 |         """
21 | super(ConvNetMaker, self).__init__()
22 | self.conv_layers = []
23 | self.fc_layers = []
24 | h, w, d = 32, 32, 3
25 | previous_layer_filter_count = 3
26 | previous_layer_size = h * w * d
27 | num_fc_layers_remained = len([1 for l in layers if l.startswith('FC')])
28 | for layer in layers:
29 | if layer.startswith('Conv'):
30 | filter_count = int(layer[4:])
31 | self.conv_layers += [nn.Conv2d(previous_layer_filter_count, filter_count, kernel_size=3, padding=1),
32 | nn.BatchNorm2d(filter_count), nn.ReLU(inplace=True)]
33 | previous_layer_filter_count = filter_count
34 | d = filter_count
35 | previous_layer_size = h * w * d
36 | elif layer.startswith('MaxPool'):
37 | self.conv_layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
38 | h, w = int(h / 2.0), int(w / 2.0)
39 | previous_layer_size = h * w * d
40 | elif layer.startswith('FC'):
41 | num_fc_layers_remained -= 1
42 | current_layer_size = int(layer[2:])
43 | if num_fc_layers_remained == 0:
44 | self.fc_layers += [nn.Linear(previous_layer_size, current_layer_size)]
45 | else:
46 | self.fc_layers += [nn.Linear(previous_layer_size, current_layer_size), nn.ReLU(inplace=True)]
47 | previous_layer_size = current_layer_size
48 |
49 |         # wrap the accumulated layer lists into sequential modules
50 |         self.conv_layers = nn.Sequential(*self.conv_layers)
51 |         self.fc_layers = nn.Sequential(*self.fc_layers)
53 |
54 | def forward(self, x):
55 | x = self.conv_layers(x)
56 | x = x.view(x.size(0), -1)
57 | x = self.fc_layers(x)
58 | return x
59 |
60 |
61 |
62 | plane_cifar10_book = {
63 |     '2': ['Conv16', 'MaxPool', 'Conv16', 'MaxPool', 'FC10'],
64 |     '3': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'MaxPool', 'FC10'],
65 |     '4': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'FC10'],
66 |     '5': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'FC10'],
67 |     '6': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC10'],
68 |     '7': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'MaxPool', 'FC64', 'FC10'],
69 |     '8': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'FC64', 'FC10'],
70 |     '9': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC10'],
71 |     '10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC128', 'FC10'],
72 | }
73 |
74 |
75 | plane_cifar100_book = {
76 | '2': ['Conv32', 'MaxPool', 'Conv32', 'MaxPool', 'FC100'],
77 | '3': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'MaxPool', 'FC100'],
78 | '4': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC100'],
79 | '5': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'FC100'],
80 |     '6': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'FC100'],
81 | '7': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'MaxPool', 'FC64', 'FC100'],
82 |     '8': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'MaxPool', 'FC64', 'FC100'],
83 | '9': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC100'],
84 | '10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC100'],
85 | }
--------------------------------------------------------------------------------
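A minimal sketch (not part of the repo): build the 4-layer plain CNN from the CIFAR-100 book and check that it maps a batch of 32x32 images to 100 logits.

```python
import torch
from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book

model = ConvNetMaker(plane_cifar100_book['4'])  # Conv32-Conv32-MaxPool-Conv64-Conv64-MaxPool-FC100
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 100])
```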
/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | """
2 | resnet for cifar in pytorch
3 |
4 | Reference:
5 | [1] K. He, X. Zhang, S. Ren, and J. Sun. Deep residual learning for image recognition. In CVPR, 2016.
6 | [2] K. He, X. Zhang, S. Ren, and J. Sun. Identity mappings in deep residual networks. In ECCV, 2016.
7 | """
8 | import torch
9 | import torch.nn as nn
10 | import math
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1):
14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
15 |
16 |
17 | class BasicBlock(nn.Module):
18 |     expansion = 1
19 |
20 | def __init__(self, inplanes, planes, stride=1, downsample=None):
21 | super(BasicBlock, self).__init__()
22 | self.conv1 = conv3x3(inplanes, planes, stride)
23 | self.bn1 = nn.BatchNorm2d(planes)
24 | self.relu = nn.ReLU(inplace=True)
25 | self.conv2 = conv3x3(planes, planes)
26 | self.bn2 = nn.BatchNorm2d(planes)
27 | self.downsample = downsample
28 | self.stride = stride
29 |
30 | def forward(self, x):
31 | residual = x
32 |
33 | out = self.conv1(x)
34 | out = self.bn1(out)
35 | out = self.relu(out)
36 |
37 | out = self.conv2(out)
38 | out = self.bn2(out)
39 |
40 | if self.downsample is not None:
41 | residual = self.downsample(x)
42 |
43 | out += residual
44 | out = self.relu(out)
45 |
46 | return out
47 |
48 |
49 | class Bottleneck(nn.Module):
50 |     expansion = 4
51 |
52 | def __init__(self, inplanes, planes, stride=1, downsample=None):
53 | super(Bottleneck, self).__init__()
54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
55 | self.bn1 = nn.BatchNorm2d(planes)
56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
57 | self.bn2 = nn.BatchNorm2d(planes)
58 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
59 | self.bn3 = nn.BatchNorm2d(planes*4)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | residual = x
66 |
67 | out = self.conv1(x)
68 | out = self.bn1(out)
69 | out = self.relu(out)
70 |
71 | out = self.conv2(out)
72 | out = self.bn2(out)
73 | out = self.relu(out)
74 |
75 | out = self.conv3(out)
76 | out = self.bn3(out)
77 |
78 | if self.downsample is not None:
79 | residual = self.downsample(x)
80 |
81 | out += residual
82 | out = self.relu(out)
83 |
84 | return out
85 |
86 |
87 | class PreActBasicBlock(nn.Module):
88 | expansion = 1
89 |
90 | def __init__(self, inplanes, planes, stride=1, downsample=None):
91 | super(PreActBasicBlock, self).__init__()
92 | self.bn1 = nn.BatchNorm2d(inplanes)
93 | self.relu = nn.ReLU(inplace=True)
94 | self.conv1 = conv3x3(inplanes, planes, stride)
95 | self.bn2 = nn.BatchNorm2d(planes)
96 | self.conv2 = conv3x3(planes, planes)
97 | self.downsample = downsample
98 | self.stride = stride
99 |
100 | def forward(self, x):
101 | residual = x
102 |
103 | out = self.bn1(x)
104 | out = self.relu(out)
105 |
106 | if self.downsample is not None:
107 | residual = self.downsample(out)
108 |
109 | out = self.conv1(out)
110 |
111 | out = self.bn2(out)
112 | out = self.relu(out)
113 | out = self.conv2(out)
114 |
115 | out += residual
116 |
117 | return out
118 |
119 |
120 | class PreActBottleneck(nn.Module):
121 | expansion = 4
122 |
123 | def __init__(self, inplanes, planes, stride=1, downsample=None):
124 | super(PreActBottleneck, self).__init__()
125 | self.bn1 = nn.BatchNorm2d(inplanes)
126 | self.relu = nn.ReLU(inplace=True)
127 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
128 | self.bn2 = nn.BatchNorm2d(planes)
129 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
130 | self.bn3 = nn.BatchNorm2d(planes)
131 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
132 | self.downsample = downsample
133 | self.stride = stride
134 |
135 | def forward(self, x):
136 | residual = x
137 |
138 | out = self.bn1(x)
139 | out = self.relu(out)
140 |
141 | if self.downsample is not None:
142 | residual = self.downsample(out)
143 |
144 | out = self.conv1(out)
145 |
146 | out = self.bn2(out)
147 | out = self.relu(out)
148 | out = self.conv2(out)
149 |
150 | out = self.bn3(out)
151 | out = self.relu(out)
152 | out = self.conv3(out)
153 |
154 | out += residual
155 |
156 | return out
157 |
158 |
159 | class ResNet_Cifar(nn.Module):
160 |
161 | def __init__(self, block, layers, num_classes=10):
162 | super(ResNet_Cifar, self).__init__()
163 | self.inplanes = 16
164 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
165 | self.bn1 = nn.BatchNorm2d(16)
166 | self.relu = nn.ReLU(inplace=True)
167 | self.layer1 = self._make_layer(block, 16, layers[0])
168 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
169 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
170 | self.avgpool = nn.AvgPool2d(8, stride=1)
171 | self.fc = nn.Linear(64 * block.expansion, num_classes)
172 |
173 |         # He (Kaiming) normal initialization for conv layers; BatchNorm weights to 1, biases to 0
174 |         for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
176 | m.weight.data.normal_(0, math.sqrt(2. / n))
177 | elif isinstance(m, nn.BatchNorm2d):
178 | m.weight.data.fill_(1)
179 | m.bias.data.zero_()
180 |
181 | def _make_layer(self, block, planes, blocks, stride=1):
182 | downsample = None
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
186 | nn.BatchNorm2d(planes * block.expansion)
187 | )
188 |
189 | layers = []
190 | layers.append(block(self.inplanes, planes, stride, downsample))
191 | self.inplanes = planes * block.expansion
192 | for _ in range(1, blocks):
193 | layers.append(block(self.inplanes, planes))
194 |
195 | return nn.Sequential(*layers)
196 |
197 | def forward(self, x):
198 | x = self.conv1(x)
199 | x = self.bn1(x)
200 | x = self.relu(x)
201 |
202 | x = self.layer1(x)
203 | x = self.layer2(x)
204 | x = self.layer3(x)
205 |
206 | x = self.avgpool(x)
207 | x = x.view(x.size(0), -1)
208 | x = self.fc(x)
209 |
210 | return x
211 |
212 |
213 | class PreAct_ResNet_Cifar(nn.Module):
214 |
215 | def __init__(self, block, layers, num_classes=10):
216 | super(PreAct_ResNet_Cifar, self).__init__()
217 | self.inplanes = 16
218 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
219 | self.layer1 = self._make_layer(block, 16, layers[0])
220 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
221 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
222 | self.bn = nn.BatchNorm2d(64*block.expansion)
223 | self.relu = nn.ReLU(inplace=True)
224 | self.avgpool = nn.AvgPool2d(8, stride=1)
225 | self.fc = nn.Linear(64*block.expansion, num_classes)
226 |
227 | for m in self.modules():
228 | if isinstance(m, nn.Conv2d):
229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
230 | m.weight.data.normal_(0, math.sqrt(2. / n))
231 | elif isinstance(m, nn.BatchNorm2d):
232 | m.weight.data.fill_(1)
233 | m.bias.data.zero_()
234 |
235 | def _make_layer(self, block, planes, blocks, stride=1):
236 | downsample = None
237 | if stride != 1 or self.inplanes != planes*block.expansion:
238 | downsample = nn.Sequential(
239 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False)
240 | )
241 |
242 | layers = []
243 | layers.append(block(self.inplanes, planes, stride, downsample))
244 | self.inplanes = planes*block.expansion
245 | for _ in range(1, blocks):
246 | layers.append(block(self.inplanes, planes))
247 | return nn.Sequential(*layers)
248 |
249 | def forward(self, x):
250 | x = self.conv1(x)
251 |
252 | x = self.layer1(x)
253 | x = self.layer2(x)
254 | x = self.layer3(x)
255 |
256 | x = self.bn(x)
257 | x = self.relu(x)
258 | x = self.avgpool(x)
259 | x = x.view(x.size(0), -1)
260 | x = self.fc(x)
261 |
262 | return x
263 |
264 |
265 |
266 | def resnet14_cifar(**kwargs):
267 | model = ResNet_Cifar(BasicBlock, [2, 2, 2], **kwargs)
268 | return model
269 |
270 | def resnet8_cifar(**kwargs):
271 | model = ResNet_Cifar(BasicBlock, [1, 1, 1], **kwargs)
272 | return model
273 |
274 |
275 | def resnet20_cifar(**kwargs):
276 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs)
277 | return model
278 |
279 | def resnet26_cifar(**kwargs):
280 | model = ResNet_Cifar(BasicBlock, [4, 4, 4], **kwargs)
281 | return model
282 |
283 | def resnet32_cifar(**kwargs):
284 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs)
285 | return model
286 |
287 |
288 | def resnet44_cifar(**kwargs):
289 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs)
290 | return model
291 |
292 |
293 | def resnet56_cifar(**kwargs):
294 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs)
295 | return model
296 |
297 |
298 | def resnet110_cifar(**kwargs):
299 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs)
300 | return model
301 |
302 |
303 | def resnet1202_cifar(**kwargs):
304 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs)
305 | return model
306 |
307 |
308 | def resnet164_cifar(**kwargs):
309 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs)
310 | return model
311 |
312 |
313 | def resnet1001_cifar(**kwargs):
314 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs)
315 | return model
316 |
317 |
318 | def preact_resnet110_cifar(**kwargs):
319 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs)
320 | return model
321 |
322 |
323 | def preact_resnet164_cifar(**kwargs):
324 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs)
325 | return model
326 |
327 |
328 | def preact_resnet1001_cifar(**kwargs):
329 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs)
330 | return model
331 |
332 | resnet_book = {
333 | '8': resnet8_cifar,
334 | '14': resnet14_cifar,
335 | '20': resnet20_cifar,
336 | '26': resnet26_cifar,
337 | '32': resnet32_cifar,
338 | '44': resnet44_cifar,
339 | '56': resnet56_cifar,
340 | '110': resnet110_cifar,
341 | }
342 |
--------------------------------------------------------------------------------
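A minimal sketch (not part of the repo): instantiate the smallest CIFAR ResNet from `resnet_book` and run a forward pass.

```python
import torch
from resnet_cifar import resnet_book

model = resnet_book['8'](num_classes=100)  # resnet8_cifar: BasicBlock, [1, 1, 1]
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 100])
```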
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import torch
4 | import argparse
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | from data_loader import get_cifar
9 | from model_factory import create_cnn_model, is_resnet
10 | import random
11 |
12 | def str2bool(v):
13 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
14 | return True
15 | else:
16 | return False
17 |
18 |
19 | def parse_arguments():
20 | parser = argparse.ArgumentParser(description='TA Knowledge Distillation Code')
21 | parser.add_argument('--epochs', default=160, type=int, help='number of total epochs to run')
22 | parser.add_argument('--dataset', default='cifar100', type=str, help='dataset. can be either cifar10 or cifar100')
23 |     parser.add_argument('--crop', default=False, type=str2bool, help='use crop/flip augmentation, True or False')
24 | parser.add_argument('--batch-size', default=128, type=int, help='batch_size')
25 | parser.add_argument('--learning-rate', default=0.1, type=float, help='initial learning rate')
26 | parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
27 | parser.add_argument('--weight-decay', default=1e-4, type=float, help='SGD weight decay (default: 1e-4)')
28 |
29 |     parser.add_argument('--T', default=5, type=int, help='distillation temperature T')
30 |     parser.add_argument('--seed', default=20, type=int, help='random seed')
31 |     parser.add_argument('--lamb', default=1, type=float, help='weight lambda balancing CE and KD losses')
32 |
33 | parser.add_argument('--teacher', default='plane10', type=str, help='teacher name')
34 | parser.add_argument('--ta1', default='plane8', type=str)
35 | parser.add_argument('--ta2', default='plane6', type=str)
36 | parser.add_argument('--ta3', default='plane4', type=str)
37 |
38 | parser.add_argument('--teacher-checkpoint', default='/path', type=str)
39 | parser.add_argument('--ta1-checkpoint', default='/path', type=str)
40 | parser.add_argument('--ta2-checkpoint', default='/path', type=str)
41 | parser.add_argument('--ta3-checkpoint', default='/path', type=str)
42 |
43 | parser.add_argument('--student', default='plane2', type=str, help='student name')
44 | parser.add_argument('--TA-count', default=3, type=int, help='TA count')
45 |
46 | parser.add_argument('--cuda', default=True, type=str2bool, help='whether or not use cuda(train on GPU)')
47 |     parser.add_argument('--gpus', default='0', type=str, help="which GPUs to use, e.g. '0' or '0,1,2,3'")
48 |     parser.add_argument('--drop-num', default=1, type=int, help='number of guiding losses randomly dropped per step (stochastic DGKD)')
49 | parser.add_argument('--dataset-dir', default='./data', type=str, help='dataset directory')
50 | args = parser.parse_args()
51 | return args
52 |
53 |
54 | def load_checkpoint(model, checkpoint_path):
55 | model_ckp = torch.load(checkpoint_path)
56 | model.load_state_dict(model_ckp['model_state_dict'])
57 | return model
58 |
59 |
60 | class TrainManager(object):
61 |     def __init__(self, student, teacher=None, ta_list=None, train_loader=None, test_loader=None, train_config=None):
62 |         train_config = train_config or {}
63 |         self.student = student
64 |         self.teacher = teacher
65 |         self.ta_list = ta_list if ta_list is not None else []
66 |
67 |         self.have_teacher = bool(self.teacher)
68 |         self.device = train_config['device']
69 |         self.name = train_config['name']
70 |         self.optimizer = optim.SGD(self.student.parameters(),
71 |                                    lr=train_config['learning_rate'],
72 |                                    momentum=train_config['momentum'],
73 |                                    weight_decay=train_config['weight_decay'])
74 |         # the teacher and the TAs are frozen guides, used in inference mode only
75 |         if self.have_teacher:
76 |             self.teacher.eval()
77 |         for ta in self.ta_list:
78 |             ta.eval()
79 |
80 |         self.train_loader = train_loader
81 |         self.test_loader = test_loader
82 |         self.config = train_config
83 |
84 |     def train(self):
85 |         lambda_ = self.config['lambda_student']
86 |         T = self.config['T_student']
87 |         epochs = self.config['epochs']
88 |         drop_num = self.config['drop_num']
89 |
90 |         best_acc = 0
91 |         criterion = nn.CrossEntropyLoss()
92 |         kld = nn.KLDivLoss()
93 |         for epoch in range(epochs):
94 |             self.student.train()
95 |             self.adjust_learning_rate(self.optimizer, epoch)
96 |             loss = 0
97 |             for data, target in self.train_loader:
98 |                 data = data.to(self.device)
99 |                 target = target.to(self.device)
100 |                 self.optimizer.zero_grad()
101 |                 student_output = self.student(data)
102 |
103 |                 # Standard Learning Loss (Classification Loss)
104 |                 loss_SL = criterion(student_output, target)
105 |
106 |                 if self.have_teacher:
107 |                     with torch.no_grad():
108 |                         teacher_output = self.teacher(data)
109 |                         ta_outputs = [ta(data) for ta in self.ta_list]
110 |
111 |                     # Teacher Knowledge Distillation Loss
112 |                     loss_KD_list = [kld(F.log_softmax(student_output / T, dim=1),
113 |                                         F.softmax(teacher_output / T, dim=1))]
114 |
115 |                     # Teacher Assistants Knowledge Distillation Loss
116 |                     for ta_output in ta_outputs:
117 |                         loss_KD_list.append(kld(F.log_softmax(student_output / T, dim=1),
118 |                                                 F.softmax(ta_output / T, dim=1)))
119 |
120 |                     # Stochastic DGKD: randomly drop some of the guiding losses
121 |                     for _ in range(min(drop_num, len(loss_KD_list))):
122 |                         loss_KD_list.pop(random.randrange(len(loss_KD_list)))
123 |
124 |                     # Total Loss
125 |                     loss = (1 - lambda_) * loss_SL + lambda_ * T * T * sum(loss_KD_list)
126 |                 else:
127 |                     loss = loss_SL
128 |
129 |                 loss.backward()
130 |                 self.optimizer.step()
131 |
132 |             print("epoch {}/{}".format(epoch, epochs))
133 |             val_acc = self.validate(step=epoch)
134 |             if val_acc > best_acc:
135 |                 best_acc = val_acc
136 |                 print('**** best val acc: ' + str(best_acc) + ' ****')
137 |                 self.save(epoch, name='DGKD_{}_{}_{}_best.pth.tar'.format(args.gpus, self.name, args.dataset))
138 |             print('loss: ', loss.item())
139 |             print()
140 |
141 |         return best_acc
142 |
143 | def validate(self, step=0):
144 | self.student.eval()
145 | with torch.no_grad():
146 | total = 0
147 | correct = 0
148 |
149 |             for images, labels in self.test_loader:
150 |                 images = images.to(self.device)
151 |                 labels = labels.to(self.device)
152 |
153 |                 output = self.student(images)
154 |                 _, predicted = torch.max(output, 1)
155 |                 total += labels.size(0)
156 |                 correct += (predicted == labels).sum().item()
157 |
158 |             acc = 100 * correct / total
159 |
160 |         return acc
161 |
162 | def save(self, epoch, name=None):
163 | torch.save({
164 | 'model_state_dict': self.student.state_dict(),
165 | 'optimizer_state_dict': self.optimizer.state_dict(),
166 | 'epoch': epoch,
167 | }, name)
168 |
169 |
170 | def adjust_learning_rate(self, optimizer, epoch):
171 | epochs = self.config['epochs']
172 | models_are_plane = self.config['is_plane']
173 |
174 |         # plain CNNs train with a fixed small LR; ResNets follow a step schedule (x0.1 at 1/2 and 3/4 of training)
175 | if models_are_plane:
176 | lr = 0.01
177 | else:
178 | if epoch < int(epochs / 2.0):
179 | lr = 0.1
180 | elif epoch < int(epochs * 3 / 4.0):
181 | lr = 0.1 * 0.1
182 | else:
183 | lr = 0.1 * 0.01
184 |
185 | # update optimizer's learning rate
186 | for param_group in optimizer.param_groups:
187 | param_group['lr'] = lr
188 |
189 |
190 | if __name__ == "__main__":
191 | # Parsing arguments and prepare settings for training
192 | args = parse_arguments()
193 | print(args)
194 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
195 |
196 |     torch.manual_seed(args.seed)
197 |     torch.cuda.manual_seed(args.seed)
198 |     random.seed(args.seed)
197 |
198 | dataset = args.dataset
199 | num_classes = 100 if dataset == 'cifar100' else 10
200 |
201 | print("---------- Creating Students -------")
202 | student_model = create_cnn_model(args.student, dataset, use_cuda=args.cuda)
203 |
204 | train_config = {
205 | 'epochs': args.epochs,
206 | 'learning_rate': args.learning_rate,
207 | 'momentum': args.momentum,
208 | 'weight_decay': args.weight_decay,
209 | 'device': 'cuda' if args.cuda else 'cpu',
210 | 'is_plane': not is_resnet(args.student),
211 | 'T_student': args.T,
212 | 'lambda_student': args.lamb,
213 | 'drop_num': args.drop_num,
214 | }
215 |
216 |     # If a teacher is given without a checkpoint, train it from scratch; training without a teacher
217 |     # is plain cross-entropy training (used for single baseline models or for the first teacher)
218 | if args.teacher:
219 | teacher_model = create_cnn_model(args.teacher, dataset, use_cuda=args.cuda)
220 | if args.teacher_checkpoint:
221 | print("---------- Loading Teacher -------")
222 | teacher_model = load_checkpoint(teacher_model, args.teacher_checkpoint)
223 | else:
224 | print("---------- Training Teacher -------")
225 | train_loader, test_loader = get_cifar(num_classes)
226 | teacher_train_config = copy.deepcopy(train_config)
227 |             # file name must match the one used by TrainManager.save
228 |             teacher_name = 'DGKD_{}_{}_{}_best.pth.tar'.format(args.gpus, args.teacher, args.dataset)
228 | teacher_train_config['name'] = args.teacher
229 | teacher_trainer = TrainManager(teacher_model, teacher=None, train_loader=train_loader,
230 | test_loader=test_loader, train_config=teacher_train_config)
231 | teacher_trainer.train()
232 | teacher_model = load_checkpoint(teacher_model, os.path.join('./', teacher_name))
233 |
234 |     # Prepare the teacher assistants (the teacher itself was created and loaded above)
235 |     print("---------- Creating TA Models ----------")
236 |     models_dict = {}
237 |     for i in range(1, args.TA_count + 1):
238 |         models_dict['model{}'.format(i)] = create_cnn_model(getattr(args, 'ta{}'.format(i)), dataset, use_cuda=args.cuda)
239 |
240 |     print("---------- Loading TA Checkpoints ----------")
241 |     ta_list = []
242 |     for i in range(1, args.TA_count + 1):
243 |         ta_list.append(load_checkpoint(models_dict['model{}'.format(i)], getattr(args, 'ta{}_checkpoint'.format(i))))
246 |
247 | # Student training
248 | print("---------- Training Student -------")
249 | student_train_config = copy.deepcopy(train_config)
250 | train_loader, test_loader = get_cifar(num_classes, crop=args.crop)
251 | student_train_config['name'] = args.student
252 | student_trainer = TrainManager(student_model, teacher=teacher_model, ta_list=ta_list,
253 | train_loader=train_loader,
254 | test_loader=test_loader,
255 | train_config=student_train_config)
256 |
257 | best_student_acc = student_trainer.train()
258 |
--------------------------------------------------------------------------------
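The heart of `TrainManager.train` is the densely guided loss: cross-entropy on the labels plus a temperature-scaled KL term against the teacher and every TA, with `drop_num` of the guiding terms dropped at random each step. Below is a standalone sketch of that computation; the function name is ours, and it uses the 'batchmean' KL reduction rather than the elementwise-mean default of `nn.KLDivLoss()` used in the repo.

```python
import random
import torch.nn.functional as F

def dgkd_loss(student_logits, guide_logits, targets, T=5.0, lamb=1.0, drop_num=1):
    """Stochastic DGKD loss; `guide_logits` holds the teacher's logits first, then each TA's."""
    loss_sl = F.cross_entropy(student_logits, targets)          # standard classification loss
    kd_terms = [F.kl_div(F.log_softmax(student_logits / T, dim=1),
                         F.softmax(g / T, dim=1), reduction='batchmean')
                for g in guide_logits]
    for _ in range(min(drop_num, len(kd_terms))):               # randomly drop guides
        kd_terms.pop(random.randrange(len(kd_terms)))
    return (1 - lamb) * loss_sl + lamb * T * T * sum(kd_terms)
```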