├── README.md
├── model_factory.py
├── data_loader.py
├── plain_cnn_cifar.py
├── resnet_cifar.py
└── train.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Densely Guided Knowledge Distillation using Multiple Teacher Assistants
--------------------------------------------------------------------------------
/model_factory.py:
--------------------------------------------------------------------------------
from resnet_cifar import *
from plain_cnn_cifar import *


def is_resnet(name):
    """
    Checks whether `name` refers to a ResNet; by convention, all ResNet names start with 'resnet'.
    :param name: model name, e.g. 'resnet32' or 'plane4'
    :return: True if the name denotes a ResNet
    """
    name = name.lower()
    return name.startswith('resnet')


def create_cnn_model(name, dataset="cifar100", use_cuda=False):
    """
    Creates a model for training, given its name and the target dataset.
    :param name: name of the model, e.g. resnet110, resnet32, plane2, plane10, ...
    :param dataset: determines the output size of the last layer. Options are 'cifar10' and 'cifar100'.
    :return: a PyTorch model (nn.Module)
    """
    num_classes = 100 if dataset == 'cifar100' else 10
    if is_resnet(name):
        # 'resnetXX' -> look up the constructor for depth XX
        resnet_size = name[6:]
        model = resnet_book.get(resnet_size)(num_classes=num_classes)
    else:
        # 'planeX' -> look up the layer specification for a plain CNN of depth X
        plane_size = name[5:]
        model_spec = plane_cifar10_book.get(plane_size) if num_classes == 10 else plane_cifar100_book.get(plane_size)
        model = ConvNetMaker(model_spec)

    # move to GPU if requested
    if use_cuda:
        model = model.cuda()

    return model


if __name__ == "__main__":
    dataset = 'cifar100'
    print('planes')
    for p in [2, 4, 6, 8, 10]:
        plane_name = "plane" + str(p)
        print(create_cnn_model(plane_name, dataset))

    print('-' * 20)
    print("resnets")
    for r in [8, 14, 20, 26, 32, 44, 56, 110]:
        resnet_name = "resnet" + str(r)
        print(create_cnn_model(resnet_name, dataset))
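A minimal usage sketch of the factory (illustrative, assuming the repository root is on the import path): build a student and a teacher, then compare their trainable-parameter counts.

from model_factory import create_cnn_model

student = create_cnn_model('plane2', dataset='cifar100')     # 2-conv plain CNN with a 100-way head
teacher = create_cnn_model('resnet110', dataset='cifar100')  # ResNet-110 teacher

def count_params(model):
    # total number of trainable parameters
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_params(student), count_params(teacher))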
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import torch
import torchvision
import torchvision.transforms as transforms
import os

NUM_WORKERS = os.cpu_count()


def get_cifar(num_classes=100, dataset_dir='./data', batch_size=128, crop=False):
    """
    :param num_classes: 10 for CIFAR-10, 100 for CIFAR-100
    :param dataset_dir: location of the datasets, defaults to a directory named 'data'
    :param batch_size: batch size, defaults to 128
    :param crop: whether to use random-crop and horizontal-flip augmentation, defaults to False
    :return: train and test DataLoader objects
    """
    normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
    simple_transform = transforms.Compose([transforms.ToTensor(), normalize])

    if crop:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ])
    else:
        train_transform = simple_transform

    if num_classes == 100:
        trainset = torchvision.datasets.CIFAR100(root=dataset_dir, train=True,
                                                 download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR100(root=dataset_dir, train=False,
                                                download=True, transform=simple_transform)
    else:
        trainset = torchvision.datasets.CIFAR10(root=dataset_dir, train=True,
                                                download=True, transform=train_transform)
        testset = torchvision.datasets.CIFAR10(root=dataset_dir, train=False,
                                               download=True, transform=simple_transform)

    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=NUM_WORKERS,
                                              pin_memory=True, shuffle=True)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, num_workers=NUM_WORKERS,
                                             pin_memory=True, shuffle=False)
    return trainloader, testloader


if __name__ == "__main__":
    print("CIFAR10")
    print(get_cifar(10))
    print("---" * 20)
    print("---" * 20)
    print("CIFAR100")
    print(get_cifar(100))
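A quick sanity check of the loaders (illustrative; the first call downloads CIFAR-100 into ./data):

from data_loader import get_cifar

train_loader, test_loader = get_cifar(num_classes=100, batch_size=128, crop=True)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])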
--------------------------------------------------------------------------------
/plain_cnn_cifar.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class ConvNetMaker(nn.Module):
    """
    Creates a simple (plain) convolutional neural network.
    """
    def __init__(self, layers):
        """
        Builds a CNN from the provided layer specification.
        The details of this list are available in the paper.
        :param layers: a list of strings describing the layers, e.g. ['Conv32', 'MaxPool', 'FC10']
        """
        super(ConvNetMaker, self).__init__()
        conv_layers = []
        fc_layers = []
        h, w, d = 32, 32, 3
        previous_layer_filter_count = 3
        previous_layer_size = h * w * d
        num_fc_layers_remaining = len([1 for l in layers if l.startswith('FC')])
        for layer in layers:
            if layer.startswith('Conv'):
                # 3x3 convolution (padding preserves spatial size), followed by BN and ReLU
                filter_count = int(layer[4:])
                conv_layers += [nn.Conv2d(previous_layer_filter_count, filter_count, kernel_size=3, padding=1),
                                nn.BatchNorm2d(filter_count), nn.ReLU(inplace=True)]
                previous_layer_filter_count = filter_count
                d = filter_count
                previous_layer_size = h * w * d
            elif layer.startswith('MaxPool'):
                # 2x2 max pooling halves the spatial resolution
                conv_layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                h, w = h // 2, w // 2
                previous_layer_size = h * w * d
            elif layer.startswith('FC'):
                # fully connected layer; the last one has no ReLU (it produces the logits)
                num_fc_layers_remaining -= 1
                current_layer_size = int(layer[2:])
                if num_fc_layers_remaining == 0:
                    fc_layers += [nn.Linear(previous_layer_size, current_layer_size)]
                else:
                    fc_layers += [nn.Linear(previous_layer_size, current_layer_size), nn.ReLU(inplace=True)]
                previous_layer_size = current_layer_size

        self.conv_layers = nn.Sequential(*conv_layers)
        self.fc_layers = nn.Sequential(*fc_layers)

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x


plane_cifar10_book = {
    '2': ['Conv16', 'MaxPool', 'Conv16', 'MaxPool', 'FC10'],
    '3': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'MaxPool', 'FC10'],
    '4': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'FC10'],
    '5': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'FC10'],
    '6': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC10'],
    '7': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'MaxPool', 'FC64', 'FC10'],
    '8': ['Conv16', 'Conv16', 'MaxPool', 'Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'FC64', 'FC10'],
    '9': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC10'],
    '10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC128', 'FC10'],
}


plane_cifar100_book = {
    '2': ['Conv32', 'MaxPool', 'Conv32', 'MaxPool', 'FC100'],
    '3': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'MaxPool', 'FC100'],
    '4': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'FC100'],
    '5': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'FC100'],
    '6': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'FC100'],
    '7': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'MaxPool', 'FC64', 'FC100'],
    '8': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'MaxPool', 'FC64', 'FC100'],
    '9': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC100'],
    '10': ['Conv32', 'Conv32', 'MaxPool', 'Conv64', 'Conv64', 'MaxPool', 'Conv128', 'Conv128', 'MaxPool', 'Conv256', 'Conv256', 'Conv256', 'Conv256', 'MaxPool', 'FC512', 'FC100'],
}
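Each book entry expands mechanically into the network above; for instance, a shape check for the CIFAR-100 'plane4' specification (illustrative):

import torch
from plain_cnn_cifar import ConvNetMaker, plane_cifar100_book

model = ConvNetMaker(plane_cifar100_book['4'])
out = model(torch.randn(2, 3, 32, 32))  # dummy CIFAR-sized batch
print(out.shape)                        # torch.Size([2, 100])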
7 | """ 8 | import torch 9 | import torch.nn as nn 10 | import math 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 15 | 16 | 17 | class BasicBlock(nn.Module): 18 | expansion=1 19 | 20 | def __init__(self, inplanes, planes, stride=1, downsample=None): 21 | super(BasicBlock, self).__init__() 22 | self.conv1 = conv3x3(inplanes, planes, stride) 23 | self.bn1 = nn.BatchNorm2d(planes) 24 | self.relu = nn.ReLU(inplace=True) 25 | self.conv2 = conv3x3(planes, planes) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | 40 | if self.downsample is not None: 41 | residual = self.downsample(x) 42 | 43 | out += residual 44 | out = self.relu(out) 45 | 46 | return out 47 | 48 | 49 | class Bottleneck(nn.Module): 50 | expansion=4 51 | 52 | def __init__(self, inplanes, planes, stride=1, downsample=None): 53 | super(Bottleneck, self).__init__() 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes*4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | residual = x 66 | 67 | out = self.conv1(x) 68 | out = self.bn1(out) 69 | out = self.relu(out) 70 | 71 | out = self.conv2(out) 72 | out = self.bn2(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv3(out) 76 | out = self.bn3(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class PreActBasicBlock(nn.Module): 88 | expansion = 1 89 | 90 | def __init__(self, inplanes, planes, stride=1, downsample=None): 91 | super(PreActBasicBlock, self).__init__() 92 | self.bn1 = nn.BatchNorm2d(inplanes) 93 | self.relu = nn.ReLU(inplace=True) 94 | self.conv1 = conv3x3(inplanes, planes, stride) 95 | self.bn2 = nn.BatchNorm2d(planes) 96 | self.conv2 = conv3x3(planes, planes) 97 | self.downsample = downsample 98 | self.stride = stride 99 | 100 | def forward(self, x): 101 | residual = x 102 | 103 | out = self.bn1(x) 104 | out = self.relu(out) 105 | 106 | if self.downsample is not None: 107 | residual = self.downsample(out) 108 | 109 | out = self.conv1(out) 110 | 111 | out = self.bn2(out) 112 | out = self.relu(out) 113 | out = self.conv2(out) 114 | 115 | out += residual 116 | 117 | return out 118 | 119 | 120 | class PreActBottleneck(nn.Module): 121 | expansion = 4 122 | 123 | def __init__(self, inplanes, planes, stride=1, downsample=None): 124 | super(PreActBottleneck, self).__init__() 125 | self.bn1 = nn.BatchNorm2d(inplanes) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 128 | self.bn2 = nn.BatchNorm2d(planes) 129 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 130 | self.bn3 = nn.BatchNorm2d(planes) 131 | self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False) 132 | self.downsample = downsample 133 | self.stride 
= stride 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | out = self.bn1(x) 139 | out = self.relu(out) 140 | 141 | if self.downsample is not None: 142 | residual = self.downsample(out) 143 | 144 | out = self.conv1(out) 145 | 146 | out = self.bn2(out) 147 | out = self.relu(out) 148 | out = self.conv2(out) 149 | 150 | out = self.bn3(out) 151 | out = self.relu(out) 152 | out = self.conv3(out) 153 | 154 | out += residual 155 | 156 | return out 157 | 158 | 159 | class ResNet_Cifar(nn.Module): 160 | 161 | def __init__(self, block, layers, num_classes=10): 162 | super(ResNet_Cifar, self).__init__() 163 | self.inplanes = 16 164 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 165 | self.bn1 = nn.BatchNorm2d(16) 166 | self.relu = nn.ReLU(inplace=True) 167 | self.layer1 = self._make_layer(block, 16, layers[0]) 168 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 169 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 170 | self.avgpool = nn.AvgPool2d(8, stride=1) 171 | self.fc = nn.Linear(64 * block.expansion, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. / n)) 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | def _make_layer(self, block, planes, blocks, stride=1): 182 | downsample = None 183 | if stride != 1 or self.inplanes != planes * block.expansion: 184 | downsample = nn.Sequential( 185 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 186 | nn.BatchNorm2d(planes * block.expansion) 187 | ) 188 | 189 | layers = [] 190 | layers.append(block(self.inplanes, planes, stride, downsample)) 191 | self.inplanes = planes * block.expansion 192 | for _ in range(1, blocks): 193 | layers.append(block(self.inplanes, planes)) 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | x = self.conv1(x) 199 | x = self.bn1(x) 200 | x = self.relu(x) 201 | 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | 206 | x = self.avgpool(x) 207 | x = x.view(x.size(0), -1) 208 | x = self.fc(x) 209 | 210 | return x 211 | 212 | 213 | class PreAct_ResNet_Cifar(nn.Module): 214 | 215 | def __init__(self, block, layers, num_classes=10): 216 | super(PreAct_ResNet_Cifar, self).__init__() 217 | self.inplanes = 16 218 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 219 | self.layer1 = self._make_layer(block, 16, layers[0]) 220 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 221 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 222 | self.bn = nn.BatchNorm2d(64*block.expansion) 223 | self.relu = nn.ReLU(inplace=True) 224 | self.avgpool = nn.AvgPool2d(8, stride=1) 225 | self.fc = nn.Linear(64*block.expansion, num_classes) 226 | 227 | for m in self.modules(): 228 | if isinstance(m, nn.Conv2d): 229 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 230 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 231 | elif isinstance(m, nn.BatchNorm2d): 232 | m.weight.data.fill_(1) 233 | m.bias.data.zero_() 234 | 235 | def _make_layer(self, block, planes, blocks, stride=1): 236 | downsample = None 237 | if stride != 1 or self.inplanes != planes*block.expansion: 238 | downsample = nn.Sequential( 239 | nn.Conv2d(self.inplanes, planes*block.expansion, kernel_size=1, stride=stride, bias=False) 240 | ) 241 | 242 | layers = [] 243 | layers.append(block(self.inplanes, planes, stride, downsample)) 244 | self.inplanes = planes*block.expansion 245 | for _ in range(1, blocks): 246 | layers.append(block(self.inplanes, planes)) 247 | return nn.Sequential(*layers) 248 | 249 | def forward(self, x): 250 | x = self.conv1(x) 251 | 252 | x = self.layer1(x) 253 | x = self.layer2(x) 254 | x = self.layer3(x) 255 | 256 | x = self.bn(x) 257 | x = self.relu(x) 258 | x = self.avgpool(x) 259 | x = x.view(x.size(0), -1) 260 | x = self.fc(x) 261 | 262 | return x 263 | 264 | 265 | 266 | def resnet14_cifar(**kwargs): 267 | model = ResNet_Cifar(BasicBlock, [2, 2, 2], **kwargs) 268 | return model 269 | 270 | def resnet8_cifar(**kwargs): 271 | model = ResNet_Cifar(BasicBlock, [1, 1, 1], **kwargs) 272 | return model 273 | 274 | 275 | def resnet20_cifar(**kwargs): 276 | model = ResNet_Cifar(BasicBlock, [3, 3, 3], **kwargs) 277 | return model 278 | 279 | def resnet26_cifar(**kwargs): 280 | model = ResNet_Cifar(BasicBlock, [4, 4, 4], **kwargs) 281 | return model 282 | 283 | def resnet32_cifar(**kwargs): 284 | model = ResNet_Cifar(BasicBlock, [5, 5, 5], **kwargs) 285 | return model 286 | 287 | 288 | def resnet44_cifar(**kwargs): 289 | model = ResNet_Cifar(BasicBlock, [7, 7, 7], **kwargs) 290 | return model 291 | 292 | 293 | def resnet56_cifar(**kwargs): 294 | model = ResNet_Cifar(BasicBlock, [9, 9, 9], **kwargs) 295 | return model 296 | 297 | 298 | def resnet110_cifar(**kwargs): 299 | model = ResNet_Cifar(BasicBlock, [18, 18, 18], **kwargs) 300 | return model 301 | 302 | 303 | def resnet1202_cifar(**kwargs): 304 | model = ResNet_Cifar(BasicBlock, [200, 200, 200], **kwargs) 305 | return model 306 | 307 | 308 | def resnet164_cifar(**kwargs): 309 | model = ResNet_Cifar(Bottleneck, [18, 18, 18], **kwargs) 310 | return model 311 | 312 | 313 | def resnet1001_cifar(**kwargs): 314 | model = ResNet_Cifar(Bottleneck, [111, 111, 111], **kwargs) 315 | return model 316 | 317 | 318 | def preact_resnet110_cifar(**kwargs): 319 | model = PreAct_ResNet_Cifar(PreActBasicBlock, [18, 18, 18], **kwargs) 320 | return model 321 | 322 | 323 | def preact_resnet164_cifar(**kwargs): 324 | model = PreAct_ResNet_Cifar(PreActBottleneck, [18, 18, 18], **kwargs) 325 | return model 326 | 327 | 328 | def preact_resnet1001_cifar(**kwargs): 329 | model = PreAct_ResNet_Cifar(PreActBottleneck, [111, 111, 111], **kwargs) 330 | return model 331 | 332 | resnet_book = { 333 | '8': resnet8_cifar, 334 | '14': resnet14_cifar, 335 | '20': resnet20_cifar, 336 | '26': resnet26_cifar, 337 | '32': resnet32_cifar, 338 | '44': resnet44_cifar, 339 | '56': resnet56_cifar, 340 | '110': resnet110_cifar, 341 | } 342 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import argparse 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | from data_loader import get_cifar 9 | from model_factory import create_cnn_model, is_resnet 10 | import random 11 | 12 | def str2bool(v): 
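With BasicBlock, a configuration [n, n, n] gives depth 6n + 2 (so [3, 3, 3] is ResNet-20 and [18, 18, 18] is ResNet-110), while Bottleneck configurations give 9n + 2 ([18, 18, 18] is ResNet-164). A shape check (illustrative):

import torch
from resnet_cifar import resnet20_cifar

model = resnet20_cifar(num_classes=100)
out = model(torch.randn(2, 3, 32, 32))  # dummy CIFAR-sized batch
print(out.shape)                        # torch.Size([2, 100])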
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import copy
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from data_loader import get_cifar
from model_factory import create_cnn_model, is_resnet
import random


def str2bool(v):
    return v.lower() in ('yes', 'true', 't', 'y', '1')


def parse_arguments():
    parser = argparse.ArgumentParser(description='TA Knowledge Distillation Code')
    parser.add_argument('--epochs', default=160, type=int, help='number of total epochs to run')
    parser.add_argument('--dataset', default='cifar100', type=str, help='dataset, either cifar10 or cifar100')
    parser.add_argument('--crop', default=False, type=str2bool, help='use crop/flip augmentation, True or False')
    parser.add_argument('--batch-size', default=128, type=int, help='batch size')
    parser.add_argument('--learning-rate', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='SGD momentum')
    parser.add_argument('--weight-decay', default=1e-4, type=float, help='SGD weight decay (default: 1e-4)')

    parser.add_argument('--T', default=5, type=int, help='distillation temperature T')
    parser.add_argument('--seed', default=20, type=int, help='random seed')
    parser.add_argument('--lamb', default=1, type=float, help='weight lambda of the distillation loss')

    parser.add_argument('--teacher', default='plane10', type=str, help='teacher name')
    parser.add_argument('--ta1', default='plane8', type=str)
    parser.add_argument('--ta2', default='plane6', type=str)
    parser.add_argument('--ta3', default='plane4', type=str)

    parser.add_argument('--teacher-checkpoint', default='/path', type=str)
    parser.add_argument('--ta1-checkpoint', default='/path', type=str)
    parser.add_argument('--ta2-checkpoint', default='/path', type=str)
    parser.add_argument('--ta3-checkpoint', default='/path', type=str)

    parser.add_argument('--student', default='plane2', type=str, help='student name')
    parser.add_argument('--TA-count', default=3, type=int, help='number of teacher assistants')

    parser.add_argument('--cuda', default=True, type=str2bool, help='whether to train on GPU')
    parser.add_argument('--gpus', default='0', type=str, help='which GPUs to use (e.g. 0,1,2,3)')
    parser.add_argument('--drop-num', default=1, type=int, help='number of randomly dropped distillation terms')
    parser.add_argument('--dataset-dir', default='./data', type=str, help='dataset directory')
    args = parser.parse_args()
    return args


def load_checkpoint(model, checkpoint_path):
    model_ckp = torch.load(checkpoint_path)
    model.load_state_dict(model_ckp['model_state_dict'])
    return model
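load_checkpoint expects the dictionary layout written by TrainManager.save further down; a round-trip sketch (the file name is a placeholder):

import torch
from model_factory import create_cnn_model
from train import load_checkpoint

model = create_cnn_model('resnet8', dataset='cifar100')
torch.save({'model_state_dict': model.state_dict()}, 'resnet8_demo.pth.tar')  # hypothetical path
model = load_checkpoint(model, 'resnet8_demo.pth.tar')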
class TrainManager(object):
    def __init__(self, student, teacher=None, ta_list=None, train_loader=None, test_loader=None, train_config={}):
        self.student = student
        self.teacher = teacher
        self.ta_list = ta_list if ta_list is not None else []

        self.have_teacher = bool(self.teacher)
        self.device = train_config['device']
        self.name = train_config['name']
        self.optimizer = optim.SGD(self.student.parameters(),
                                   lr=train_config['learning_rate'],
                                   momentum=train_config['momentum'],
                                   weight_decay=train_config['weight_decay'])
        # teacher and TAs are frozen: they only provide soft targets
        if self.have_teacher:
            self.teacher.eval()
        for ta in self.ta_list:
            ta.eval()

        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = train_config

    def train(self):
        lambda_ = self.config['lambda_student']
        T = self.config['T_student']
        epochs = self.config['epochs']
        drop_num = self.config['drop_num']

        best_acc = 0
        criterion = nn.CrossEntropyLoss()
        for epoch in range(epochs):
            self.student.train()
            self.adjust_learning_rate(self.optimizer, epoch)
            loss = 0
            for data, target in self.train_loader:
                data = data.to(self.device)
                target = target.to(self.device)
                self.optimizer.zero_grad()
                student_output = self.student(data)

                # Standard Learning Loss (Classification Loss)
                loss_SL = criterion(student_output, target)

                if self.have_teacher:
                    teacher_output = self.teacher(data)
                    ta_outputs = [ta(data) for ta in self.ta_list]

                    # Teacher Knowledge Distillation Loss
                    loss_KD_list = [nn.KLDivLoss()(F.log_softmax(student_output / T, dim=1),
                                                   F.softmax(teacher_output / T, dim=1))]

                    # Teacher Assistants Knowledge Distillation Loss
                    for ta_output in ta_outputs:
                        loss_KD_list.append(nn.KLDivLoss()(F.log_softmax(student_output / T, dim=1),
                                                           F.softmax(ta_output / T, dim=1)))

                    # Stochastic DGKD: randomly drop some of the guidance terms
                    for _ in range(drop_num):
                        loss_KD_list.pop(random.randrange(len(loss_KD_list)))

                    # Total Loss
                    loss = (1 - lambda_) * loss_SL + lambda_ * T * T * sum(loss_KD_list)
                else:
                    # no teacher: plain cross-entropy training (used for baselines / the first teacher)
                    loss = loss_SL

                loss.backward()
                self.optimizer.step()

            print("epoch {}/{}".format(epoch, epochs))
            val_acc = self.validate(step=epoch)
            if val_acc > best_acc:
                best_acc = val_acc
                print('**** best val acc: ' + str(best_acc) + ' ****')
                self.save(epoch, name='DGKD_{}_{}_{}_best.pth.tar'.format(args.gpus, self.name, args.dataset))
            print('loss: ', loss.data)
            print()

        return best_acc

    def validate(self, step=0):
        self.student.eval()
        with torch.no_grad():
            total = 0
            correct = 0

            for images, labels in self.test_loader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                output = self.student(images)

                _, predicted = torch.max(output.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            acc = 100 * correct / total

            return acc

    def save(self, epoch, name=None):
        torch.save({
            'model_state_dict': self.student.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': epoch,
        }, name)

    def adjust_learning_rate(self, optimizer, epoch):
        epochs = self.config['epochs']
        models_are_plane = self.config['is_plane']

        # plain CNNs use a constant learning rate;
        # ResNets use a step schedule that decays at 1/2 and 3/4 of training
        if models_are_plane:
            lr = 0.01
        else:
            if epoch < int(epochs / 2.0):
                lr = 0.1
            elif epoch < int(epochs * 3 / 4.0):
                lr = 0.1 * 0.1
            else:
                lr = 0.1 * 0.01

        # update optimizer's learning rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
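train() above optimizes loss = (1 − λ)·CE(student, y) + λ·T²·Σ_k KL(p_k^T ‖ p_s^T), where k runs over the teacher and the assistants that survive the random drop, and p^T denotes a softmax at temperature T. A self-contained sketch of just this objective (hypothetical helper mirroring the loop above; it additionally detaches the guide logits and always keeps at least one guidance term):

import random
import torch.nn as nn
import torch.nn.functional as F

def dgkd_loss(student_logits, guide_logits_list, target, T=5, lamb=1.0, drop_num=1):
    # hard-label classification term
    loss_sl = nn.CrossEntropyLoss()(student_logits, target)
    # one KD term per guide (teacher first, then each assistant)
    kd_terms = [nn.KLDivLoss()(F.log_softmax(student_logits / T, dim=1),
                               F.softmax(g.detach() / T, dim=1))
                for g in guide_logits_list]
    # stochastic DGKD: randomly drop guidance terms
    for _ in range(min(drop_num, len(kd_terms) - 1)):
        kd_terms.pop(random.randrange(len(kd_terms)))
    return (1 - lamb) * loss_sl + lamb * T * T * sum(kd_terms)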
if __name__ == "__main__":
    # Parse arguments and prepare settings for training
    args = parse_arguments()
    print(args)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset = args.dataset
    num_classes = 100 if dataset == 'cifar100' else 10

    print("---------- Creating Student -------")
    student_model = create_cnn_model(args.student, dataset, use_cuda=args.cuda)

    train_config = {
        'epochs': args.epochs,
        'learning_rate': args.learning_rate,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay,
        'device': 'cuda' if args.cuda else 'cpu',
        'is_plane': not is_resnet(args.student),
        'T_student': args.T,
        'lambda_student': args.lamb,
        'drop_num': args.drop_num,
    }

    # Load the teacher if a checkpoint is provided; otherwise train it from scratch
    # using only the cross-entropy loss. The same path is used for training single
    # baseline models (or the first teacher).
    teacher_model = create_cnn_model(args.teacher, dataset, use_cuda=args.cuda)
    if args.teacher_checkpoint:
        print("---------- Loading Teacher -------")
        teacher_model = load_checkpoint(teacher_model, args.teacher_checkpoint)
    else:
        print("---------- Training Teacher -------")
        train_loader, test_loader = get_cifar(num_classes)
        teacher_train_config = copy.deepcopy(train_config)
        teacher_train_config['name'] = args.teacher
        teacher_trainer = TrainManager(teacher_model, teacher=None, train_loader=train_loader,
                                       test_loader=test_loader, train_config=teacher_train_config)
        teacher_trainer.train()
        teacher_name = 'DGKD_{}_{}_{}_best.pth.tar'.format(args.gpus, args.teacher, args.dataset)
        teacher_model = load_checkpoint(teacher_model, os.path.join('./', teacher_name))

    # Prepare the teacher assistants
    print("---------- Creating TAs ----------")
    models_dict = {}
    for i in range(1, args.TA_count + 1):
        models_dict['model{}'.format(i)] = create_cnn_model(getattr(args, 'ta{}'.format(i)), dataset, use_cuda=args.cuda)

    print("---------- Loading TAs ----------")
    ta_list = []
    for i in range(1, args.TA_count + 1):
        ta_list.append(load_checkpoint(models_dict['model{}'.format(i)],
                                       getattr(args, 'ta{}_checkpoint'.format(i))))
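    # (illustrative sanity check, not in the original script: every loaded guide
    #  should produce logits with num_classes entries)
    with torch.no_grad():
        dummy = torch.randn(2, 3, 32, 32).to(train_config['device'])
        for m in [teacher_model] + ta_list:
            print(m(dummy).shape)  # expect torch.Size([2, num_classes])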
    # Student training
    print("---------- Training Student -------")
    student_train_config = copy.deepcopy(train_config)
    train_loader, test_loader = get_cifar(num_classes, crop=args.crop)
    student_train_config['name'] = args.student
    student_trainer = TrainManager(student_model, teacher=teacher_model, ta_list=ta_list,
                                   train_loader=train_loader,
                                   test_loader=test_loader,
                                   train_config=student_train_config)

    best_student_acc = student_trainer.train()
--------------------------------------------------------------------------------
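For reference, a typical student-training invocation might look like the following (checkpoint paths and hyperparameter values are placeholders; the flags are those defined in parse_arguments above):

python train.py --dataset cifar100 --teacher resnet110 --ta1 resnet56 --ta2 resnet32 --ta3 resnet20 \
    --student resnet8 --TA-count 3 --teacher-checkpoint ./ckpt/resnet110_best.pth.tar \
    --ta1-checkpoint ./ckpt/resnet56_best.pth.tar --ta2-checkpoint ./ckpt/resnet32_best.pth.tar \
    --ta3-checkpoint ./ckpt/resnet20_best.pth.tar --T 5 --lamb 0.5 --drop-num 1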