├── Poster-SCAN-v0.1.pptx ├── README.md ├── inference.py ├── sresnet.py └── train.py /Poster-SCAN-v0.1.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArchipLab-LinfengZhang/pytorch-scalable-neural-networks/aac136e1a009b95e75ba4cc7560902b54220efde/Poster-SCAN-v0.1.pptx -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCAN: A Scalable Neural Networks Framework for Compact and Efficient Models 2 | A PyTorch implementation of the paper *SCAN: A Scalable Neural Networks Framework for Compact and Efficient Models*. 3 | An advanced version has been released at https://github.com/ArchipLab-LinfengZhang/pytorch-self-distillation-final. 4 | 5 | ## Requirements 6 | Install PyTorch>=1.0.0, torchvision>=0.2.0. 7 | 8 | Download and process the CIFAR datasets with torchvision. 9 | 10 | ## How to train 11 | python train.py [--depth=18] [--class_num=100] [--epoch=200] [--lambda_KD=0.5] 12 | **depth** is the number of layers in the ResNet. 13 | 14 | **class_num** selects which dataset is used (10 for CIFAR-10, 100 for CIFAR-100). 15 | 16 | **epoch** is the number of epochs the model is trained for. 17 | 18 | **lambda_KD** is a hyper-parameter that balances the distillation loss against the cross-entropy loss. 19 | 20 | ## Dynamic inference 21 | 22 | python inference.py [--depth=18] 23 | Only a pre-trained ResNet18 model is provided for now, stored in the **model** folder. inference.py uses it for inference and prints its accuracy and acceleration ratio. By adjusting the thresholds on line 30 of inference.py (the `dic` dictionary in `judge`), you can obtain different accuracy and acceleration results. 
# --------------------------------------------------------------------------------
# /inference.py:
# --------------------------------------------------------------------------------
"""Early-exit ("dynamic") inference for a trained scalable ResNet.

For every test image the classifiers are consulted from the shallowest to the
deepest; the first one whose top softmax probability exceeds its threshold
(see ``judge``) provides the prediction.  When none is confident enough, the
average of all classifier logits is used instead.
"""
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import sresnet
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(100)
torch.cuda.manual_seed(100)

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--depth', default=18, type=int)
parser.add_argument('--class_num', default=100, type=int)
parser.add_argument('--epoch', default=200, type=int)
parser.add_argument('--lambda_KD', default=0.5, type=float)
args = parser.parse_args()
print(args)


def CrossEntropy(outputs, targets):
    """Return the mean cross entropy between temperature-softened logits.

    Both ``outputs`` and ``targets`` are divided by 3.0 before the
    (log-)softmax over dimension 1.  Not called by this script.
    """
    log_softmax_outputs = F.log_softmax(outputs/3.0, dim=1)
    softmax_targets = F.softmax(targets/3.0, dim=1)
    return -(log_softmax_outputs * softmax_targets).sum(dim=1).mean()


def judge(tensor, c):
    """Return True when classifier ``c`` is confident enough to exit early.

    Args:
        tensor: softmax probabilities of one sample.
        c: classifier index, 0 (shallowest) to 3 (deepest).
    """
    # Per-classifier confidence thresholds; tune these to trade accuracy
    # against acceleration.
    dic = {0: 0.98, 1: 0.97, 2: 0.98, 3: 0.95}
    return float(torch.max(tensor)) > dic[c]


BATCH_SIZE = 256
LR = 0.1

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset, testset = None, None
if args.class_num == 100:
    print("dataset: CIFAR100")
    # NOTE: machine-specific absolute path and download=False -- the dataset
    # must already exist at this location.
    trainset = torchvision.datasets.CIFAR100(
        root='/home2/lthpc/data',
        train=True,
        download=False,
        transform=transform_train
    )
    testset = torchvision.datasets.CIFAR100(
        root='/home2/lthpc/data',
        train=False,
        download=False,
        transform=transform_test
    )
elif args.class_num == 10:
    print("dataset: CIFAR10")
    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=False,
        transform=transform_train
    )
    testset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=False,
        transform=transform_test
    )
else:
    # Fail early instead of handing ``None`` to the DataLoader below.
    raise ValueError("--class_num must be 10 or 100, got %d" % args.class_num)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4
)

# Supported depths and their constructors.
MODEL_BUILDERS = {
    18: sresnet.resnet18,
    50: sresnet.resnet50,
    101: sresnet.resnet101,
    152: sresnet.resnet152,
}
if args.depth not in MODEL_BUILDERS:
    # Previously an unsupported depth left ``net`` as None and crashed later.
    raise ValueError("--depth must be one of %s, got %d"
                     % (sorted(MODEL_BUILDERS), args.depth))
net = MODEL_BUILDERS[args.depth](num_classes=args.class_num, align="CONV")
print("using resnet %d" % args.depth)

net.to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
net.load_state_dict(torch.load("bestmodel.pth", map_location=device))


if __name__ == "__main__":
    # caught[c]: samples answered by classifier c (0 = shallowest);
    # caught[-1]: samples that fell through to the ensemble.
    caught = [0, 0, 0, 0, 0]
    # Weights applied to ``caught`` in the acceleration ratio.
    # NOTE(review): presumably the relative compute cost of each exit --
    # confirm against the paper.
    exit_costs = (0.32, 0.53, 0.76, 1.0, 1.07)
    print("Waiting Test!")
    net.eval()
    with torch.no_grad():
        total = 0.0
        right = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs, feature_loss = net(images)
            # Fallback prediction: average of the raw logits of all heads.
            ensemble = sum(outputs) / len(outputs)
            # The network returns the deepest head first; reverse so that
            # index 0 is the shallowest head, then convert to probabilities.
            probs = [F.softmax(head, dim=1) for head in reversed(outputs)]

            for index in range(images.size(0)):
                for c, head_probs in enumerate(probs):
                    logits = head_probs[index]
                    if judge(logits, c):
                        caught[c] += 1
                        break
                else:
                    # No head was confident enough.
                    caught[-1] += 1
                    logits = ensemble[index]

                if int(torch.argmax(logits)) == int(labels[index]):
                    right += 1

            total += float(labels.size(0))

        print('Test Set Accuracy: %.4f%% ' % (100 * right / total))
        weighted_cost = sum(cost * count for cost, count in zip(exit_costs, caught))
        acceleration_ratio = 1 / (weighted_cost / total)

        print("Acceleration ratio:", acceleration_ratio)


# --------------------------------------------------------------------------------
# /sresnet.py:
# --------------------------------------------------------------------------------
"""ResNet variants with one classifier head after each of the four stages."""
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


# NOTE(review): these are local file paths, yet they are handed to
# ``model_zoo.load_url`` when ``pretrained=True`` -- confirm that this path
# is ever exercised.
model_urls = {
    'resnet18': './pretrain/resnet18-5c106cde.pth',
    'resnet34': './pretrain/resnet34-333f7ec4.pth',
    'resnet50': './pretrain/resnet50-19c8e357.pth',
    'resnet101': './pretrain/resnet101-5d3b4d8f.pth',
    'resnet152': './pretrain/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions."""

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the shortcut when the main branch changes shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Residual block made of 1x1 -> 3x3 -> 1x1 convolutions (4x expansion)."""

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the shortcut when the main branch changes shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


def ScalaNet(channel_in, channel_out, size):
    """Build the branch that maps a stage's features to ``channel_out`` channels.

    Args:
        channel_in: channels of the incoming feature map.
        channel_out: channels produced before the final average pooling.
        size: kernel size and stride of the middle (spatially reducing) conv.
    """
    return nn.Sequential(
        nn.Conv2d(channel_in, 128, kernel_size=1, stride=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, 128, kernel_size=size, stride=size),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.Conv2d(128, channel_out, kernel_size=1, stride=1),
        nn.BatchNorm2d(channel_out),
        nn.ReLU(),
        nn.AvgPool2d(4, 4)
    )


def _attention_block(channels):
    """Build a stride-2 conv / stride-2 transposed-conv mask ending in a sigmoid."""
    return nn.Sequential(
        nn.Conv2d(kernel_size=3, padding=1, stride=2,
                  in_channels=channels, out_channels=channels),
        nn.BatchNorm2d(channels),
        nn.ReLU(),
        nn.ConvTranspose2d(kernel_size=4, padding=1, stride=2,
                           in_channels=channels, out_channels=channels),
        nn.BatchNorm2d(channels),
        nn.Sigmoid()
    )


class ResNet(nn.Module):
    """ResNet backbone with a classifier attached after each of its four stages.

    ``forward`` returns ``([out4, out3, out2, out1], feature_loss)``: the
    logits of every classifier, deepest first, plus the summed squared
    distance between the (detached) deepest feature vector and each of the
    three shallower ones.
    """

    def __init__(self, block, layers, num_classes=100, zero_init_residual=False, align="CONV"):
        super(ResNet, self).__init__()
        print("num_class: ", num_classes)
        self.inplanes = 64
        # Stored only; nothing in this class branches on it.
        self.align = align
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        print("CONV for aligning")
        # Branches that bring every stage to 512 * expansion channels.
        self.scala1 = ScalaNet(
            channel_in=64 * block.expansion,
            channel_out=512 * block.expansion,
            size=8
        )
        self.scala2 = ScalaNet(
            channel_in=128 * block.expansion,
            channel_out=512 * block.expansion,
            size=4
        )
        self.scala3 = ScalaNet(
            channel_in=256 * block.expansion,
            channel_out=512 * block.expansion,
            size=2
        )
        self.scala4 = nn.AvgPool2d(4, 4)

        # Sigmoid masks multiplied onto the features of stages 1-3.
        self.attention1 = _attention_block(64 * block.expansion)
        self.attention2 = _attention_block(128 * block.expansion)
        self.attention3 = _attention_block(256 * block.expansion)

        # One linear classifier per stage.
        self.fc1 = nn.Linear(512 * block.expansion, num_classes)
        self.fc2 = nn.Linear(512 * block.expansion, num_classes)
        self.fc3 = nn.Linear(512 * block.expansion, num_classes)
        self.fc4 = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one may downsample."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``([out4, out3, out2, out1], feature_loss)`` for a batch ``x``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        fea1 = self.attention1(x) * x

        x = self.layer2(x)
        fea2 = self.attention2(x) * x

        x = self.layer3(x)
        fea3 = self.attention3(x) * x

        x = self.layer4(x)

        batch = x.size(0)
        out1_feature = self.scala1(fea1).view(batch, -1)
        out2_feature = self.scala2(fea2).view(batch, -1)
        out3_feature = self.scala3(fea3).view(batch, -1)
        out4_feature = self.scala4(x).view(batch, -1)

        # The deepest feature is the target; detach so that the alignment
        # loss does not back-propagate into it.
        teacher_feature = out4_feature.detach()
        feature_loss = ((teacher_feature - out3_feature)**2 + (teacher_feature - out2_feature)**2 +
                        (teacher_feature - out1_feature)**2).sum()

        out1 = self.fc1(out1_feature)
        out2 = self.fc2(out2_feature)
        out3 = self.fc3(out3_feature)
        out4 = self.fc4(out4_feature)

        return [out4, out3, out2, out1], feature_loss


def _resnet(arch, block, layers, pretrained, **kwargs):
    """Build a ``ResNet`` and optionally load the weights registered for ``arch``."""
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls[arch]))
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, **kwargs)


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, **kwargs)


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, **kwargs)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, **kwargs)


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, **kwargs)


# --------------------------------------------------------------------------------
# /train.py:
# --------------------------------------------------------------------------------
"""Train a scalable ResNet on CIFAR-10/100.

The deepest classifier is trained with cross entropy on the labels; every
shallower classifier is trained on a mix of the labels and the detached
logits of the deepest classifier, plus a feature-alignment term.
"""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import sresnet
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(100)
torch.cuda.manual_seed(100)

parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--depth', default=18, type=int)
parser.add_argument('--class_num', default=100, type=int)
parser.add_argument('--epoch', default=200, type=int)
parser.add_argument('--lambda_KD', default=0.5, type=float)
args = parser.parse_args()
print(args)


def CrossEntropy(outputs, targets):
    """Return the mean cross entropy between temperature-softened logits.

    Both ``outputs`` (student) and ``targets`` (teacher) are divided by 3.0
    before the (log-)softmax over dimension 1.
    """
    log_softmax_outputs = F.log_softmax(outputs/3.0, dim=1)
    softmax_targets = F.softmax(targets/3.0, dim=1)
    return -(log_softmax_outputs * softmax_targets).sum(dim=1).mean()


BATCH_SIZE = 128
LR = 0.1

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset, testset = None, None
if args.class_num == 100:
    print("dataset: CIFAR100")
    # NOTE: machine-specific absolute path and download=False -- the dataset
    # must already exist at this location.
    trainset = torchvision.datasets.CIFAR100(
        root='/home2/lthpc/data',
        train=True,
        download=False,
        transform=transform_train
    )
    testset = torchvision.datasets.CIFAR100(
        root='/home2/lthpc/data',
        train=False,
        download=False,
        transform=transform_test
    )
elif args.class_num == 10:
    print("dataset: CIFAR10")
    trainset = torchvision.datasets.CIFAR10(
        root='./data',
        train=True,
        download=False,
        transform=transform_train
    )
    testset = torchvision.datasets.CIFAR10(
        root='./data',
        train=False,
        download=False,
        transform=transform_test
    )
else:
    # Fail early instead of handing ``None`` to the DataLoader below.
    raise ValueError("--class_num must be 10 or 100, got %d" % args.class_num)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4
)

# Supported depths and their constructors.
MODEL_BUILDERS = {
    18: sresnet.resnet18,
    50: sresnet.resnet50,
    101: sresnet.resnet101,
    152: sresnet.resnet152,
}
if args.depth not in MODEL_BUILDERS:
    # Previously an unsupported depth left ``net`` as None and crashed later.
    raise ValueError("--depth must be one of %s, got %d"
                     % (sorted(MODEL_BUILDERS), args.depth))
net = MODEL_BUILDERS[args.depth](num_classes=args.class_num, align="CONV")
print("using resnet %d" % args.depth)

net.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LR, weight_decay=5e-4, momentum=0.9)

if __name__ == "__main__":
    best_acc = 0
    checkpoint_path = "./4att/bestmodel.pth"
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    print("Start Training")  # args.epoch is the number of passes over the dataset
    # acc.txt receives the per-epoch test summary, log.txt the per-iteration
    # training line (the same text that is printed).
    with open("acc.txt", "w") as f, open("log.txt", "w") as f2:
        for epoch in range(args.epoch):
            # Step-wise learning-rate decay.
            if epoch in [75, 130, 180]:
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10
            net.train()
            sum_loss = 0.0
            total = 0.0
            # correct[0..3]: heads from deepest to shallowest; correct[4]: ensemble.
            correct = [0.0] * 5
            length = len(trainloader)
            for i, (inputs, labels) in enumerate(trainloader):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs, feature_loss = net(inputs)

                # Used for accuracy reporting only; averaged over all heads
                # so that it matches the test-time ensemble.
                ensemble = (sum(outputs) / len(outputs)).detach()

                # Deepest classifier: plain cross entropy on the labels.
                loss = criterion(outputs[0], labels)

                # Shallower classifiers: soft targets from the detached
                # deepest logits plus hard targets from the labels.
                # NOTE(review): the factor 9 is presumably the squared
                # temperature (3.0 in CrossEntropy) -- confirm.
                teacher_output = outputs[0].detach()
                for student_output in outputs[1:]:
                    loss = loss + CrossEntropy(student_output, teacher_output) * args.lambda_KD * 9
                    loss = loss + criterion(student_output, labels) * (1 - args.lambda_KD)

                # Feature alignment loss.
                if args.lambda_KD != 0:
                    loss = loss + feature_loss * 5e-7

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total += float(labels.size(0))
                sum_loss += loss.item()

                for k, logits in enumerate(outputs + [ensemble]):
                    _, predicted = torch.max(logits.detach(), 1)
                    correct[k] += float(predicted.eq(labels).sum())

                train_msg = ('[epoch:%d, iter:%d] Loss: %.03f | Acc: 4/4: %.2f%% 3/4: %.2f%% 2/4: %.2f%% 1/4: %.2f%%'
                             ' Ensemble: %.2f%%' % (epoch + 1, (i + 1 + epoch * length), sum_loss / (i + 1),
                                                    100 * correct[0] / total, 100 * correct[1] / total,
                                                    100 * correct[2] / total, 100 * correct[3] / total,
                                                    100 * correct[4] / total))
                print(train_msg)
                f2.write(train_msg + "\n")
                f2.flush()

            print("Waiting Test!")
            net.eval()
            with torch.no_grad():
                correct = [0.0] * 5
                total = 0.0
                for images, labels in testloader:
                    images, labels = images.to(device), labels.to(device)
                    outputs, feature_loss = net(images)
                    ensemble = sum(outputs) / len(outputs)

                    for k, logits in enumerate(outputs + [ensemble]):
                        _, predicted = torch.max(logits, 1)
                        correct[k] += float(predicted.eq(labels).sum())
                    total += float(labels.size(0))

                test_msg = ('Test Set AccuracyAcc: 4/4: %.4f%% 3/4: %.4f%% 2/4: %.4f%% 1/4: %.4f%%'
                            ' Ensemble: %.4f%%' % (100 * correct[0] / total, 100 * correct[1] / total,
                                                   100 * correct[2] / total, 100 * correct[3] / total,
                                                   100 * correct[4] / total))
                print(test_msg)
                f.write(test_msg + "\n")
                f.flush()
                # Keep the checkpoint with the best deepest-classifier accuracy.
                if correct[0] / total > best_acc:
                    torch.save(net.state_dict(), checkpoint_path)
                    print("model saved")
                    best_acc = correct[0] / total

        print("Training Finished, TotalEPOCH=%d" % args.epoch)

# --------------------------------------------------------------------------------