├── .gitignore
├── LICENSE
├── README.md
├── base.py
├── dlb.py
├── models
│   ├── __init__.py
│   ├── densenet.py
│   ├── resnet.py
│   └── vgg.py
├── my_dataloader.py
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Meta-knowledge-Lab
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Self-Distillation from the Last Mini-Batch (DLB)
2 | 
3 | This is a PyTorch implementation of "Self-Distillation from the Last Mini-Batch for Consistency Regularization", accepted at CVPR 2022.
4 | 
5 | The paper is available at [https://arxiv.org/abs/2203.16172](https://arxiv.org/abs/2203.16172).
6 | 
7 | 
8 | Run `dlb.py` for the proposed self-distillation method; `base.py` trains the same backbones without distillation as a baseline.
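Both scripts read their configuration from `argparse` flags. `dlb.py` defines `--T`, `--alpha`, `--batch_size`, `--epoch`, `--lr`, and `--milestones` without defaults, so they must be passed explicitly, e.g. `python dlb.py --model_names vgg16 --dataset cifar100 --T 3.0 --alpha 1.0 --batch_size 64 --epoch 240 --lr 0.05 --milestones 150 180 210` (the values here are illustrative placeholders, not the paper's tuned settings).

For orientation, below is a minimal sketch of the objective that `dlb.py` optimizes. It assumes `out`/`out_pre` are the current step's logits for the current and the re-forwarded previous mini-batch, and `pre_out` are the logits the previous step produced for that same batch:

```python
import torch
import torch.nn.functional as F

def dlb_loss(out, label, out_pre, pre_label, pre_out, T, alpha):
    # Hard-label cross-entropy over the current batch and the
    # re-forwarded previous batch.
    ce = F.cross_entropy(torch.cat((out_pre, out), dim=0),
                         torch.cat((pre_label, label), dim=0))
    # Consistency term: the previous step's (detached) soft predictions
    # teach the current step on the same images; T * T rescales gradients.
    kd = F.kl_div(F.log_softmax(out_pre / T, dim=1),
                  F.softmax(pre_out.detach() / T, dim=1),
                  reduction="batchmean") * T * T
    return ce + alpha * kd
```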
9 | 
10 | 
11 | 
12 | 
--------------------------------------------------------------------------------
/base.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | from torch.optim.lr_scheduler import MultiStepLR
5 | import torch.nn.functional as F
6 | from my_dataloader import get_dataloader
7 | from models import model_dict
8 | import os
9 | from utils import AverageMeter, accuracy
10 | import numpy as np
11 | from datetime import datetime
12 | 
13 | torch.backends.cudnn.benchmark = True  # NOTE: benchmark=True lets cuDNN pick non-deterministic kernels,
14 | torch.backends.cudnn.deterministic = True  # so set benchmark=False if exact reproducibility matters
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument(
18 |     "--model_names", type=str, nargs="+", default=["resnet32", "resnet32"]  # "resnet20" is not a key of model_dict
19 | )
20 | 
21 | parser.add_argument("--root", type=str, default="./dataset")
22 | parser.add_argument("--num_workers", type=int, default=16)
23 | parser.add_argument(
24 |     "--dataset",
25 |     type=str,
26 |     default="cifar100",
27 |     choices=["cifar100", "cifar10", "CUB", "tinyimagenet"],
28 |     help="dataset",
29 | )
30 | parser.add_argument("--classes_num", type=int, default=100)
31 | 
32 | parser.add_argument("--batch_size", type=int, default=64)
33 | parser.add_argument("--epoch", type=int, default=240)
34 | parser.add_argument("--lr", type=float, default=0.05)
35 | parser.add_argument("--momentum", type=float, default=0.9)
36 | parser.add_argument("--weight-decay", type=float, default=5e-4)
37 | parser.add_argument("--gamma", type=float, default=0.1)
38 | parser.add_argument("--milestones", type=int, nargs="+", default=[150, 180, 210])
39 | 
40 | 
41 | parser.add_argument("--seed", type=int, default=1)
42 | parser.add_argument("--gpu-id", type=int, default=0)
43 | parser.add_argument("--print_freq", type=int, default=100)
44 | parser.add_argument("--aug_nums", type=int, default=2)
45 | parser.add_argument("--exp_postfix", type=str, default="base")
46 | 
47 | args = parser.parse_args()
48 | args.num_branch = len(args.model_names)
49 | 
50 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)  # must be set before anything initializes CUDA
51 | torch.manual_seed(args.seed)
52 | np.random.seed(args.seed)
53 | torch.cuda.manual_seed(args.seed)
54 | 
55 | exp_name = "_".join(args.model_names) + args.exp_postfix
56 | exp_path = "./baseline/{}/{}".format(args.dataset, exp_name)
57 | os.makedirs(exp_path, exist_ok=True)
58 | print(exp_path)
59 | 
60 | 
61 | def train_one_epoch(models, optimizers, train_loader):
62 |     acc_recorder_list = []
63 |     loss_recorder_list = []
64 |     for model in models:
65 |         model.train()
66 |         acc_recorder_list.append(AverageMeter())
67 |         loss_recorder_list.append(AverageMeter())
68 | 
69 |     for i, (imgs, label) in enumerate(train_loader):
70 |         if torch.cuda.is_available():  # move the batch once, not once per model
71 |             imgs = imgs.cuda()
72 |             label = label.cuda()
73 | 
74 |         # forward: model k trains on its own augmented view imgs[:, k, ...],
75 |         # so the dataloader must stack at least len(models) views (args.aug_nums)
76 |         out_list = []
77 |         for model_idx, model in enumerate(models):
78 |             out = model(imgs[:, model_idx, ...])
79 |             out_list.append(out)
80 | 
81 |         # backward: each model is optimized independently on the hard labels
82 | 
83 |         for model_idx, model in enumerate(models):
84 |             loss = F.cross_entropy(out_list[model_idx], label)
85 |             optimizers[model_idx].zero_grad()
86 |             loss.backward()
87 |             optimizers[model_idx].step()
88 | 
89 |             loss_recorder_list[model_idx].update(loss.item(), n=imgs.size(0))
90 |             acc = accuracy(out_list[model_idx], label)[0]
91 |             acc_recorder_list[model_idx].update(acc.item(), n=imgs.size(0))
92 | 
93 |     losses = [recorder.avg for recorder in loss_recorder_list]
94 |     acces = [recorder.avg for recorder in acc_recorder_list]
95 |     return losses, acces
96 | 
97 | 
98 | def evaluation(models, val_loader):
99 |     acc_recorder_list = []
100 |     loss_recorder_list = []
101 |     for model in models:
102 |         model.eval()
103 |         acc_recorder_list.append(AverageMeter())
104 |         loss_recorder_list.append(AverageMeter())
105 | 
106 |     with torch.no_grad():
107 |         for img, label in val_loader:
108 |             if torch.cuda.is_available():
109 |                 img = img.cuda()
110 |                 label = label.cuda()
111 | 
112 |             for model_idx, model in enumerate(models):
113 |                 out = model(img)
114 |                 acc = accuracy(out, label)[0]
115 |                 loss = F.cross_entropy(out, label)
116 |                 acc_recorder_list[model_idx].update(acc.item(), img.size(0))
117 |                 loss_recorder_list[model_idx].update(loss.item(), img.size(0))
118 |     losses = [recorder.avg for recorder in loss_recorder_list]
119 |     acces = [recorder.avg for recorder in acc_recorder_list]
120 |     return losses, acces
121 | 
122 | 
123 | def train(model_list, optimizer_list, train_loader, scheduler_list):
124 |     best_acc = [-1 for _ in range(args.num_branch)]
125 | 
126 |     f_writers = []
127 |     for i in range(len(model_list)):
128 |         f = open(os.path.join(exp_path, "log_{}test.txt".format(i)), "w")
129 |         f_writers.append(f)
130 | 
131 |     for epoch in range(args.epoch):
132 |         train_losses, train_acces = train_one_epoch(
133 |             model_list, optimizer_list, train_loader
134 |         )
135 |         val_losses, val_acces = evaluation(model_list, val_loader)  # val_loader is module-level (set in __main__)
136 | 
137 |         for i in range(len(best_acc)):
138 |             if val_acces[i] > best_acc[i]:
139 |                 best_acc[i] = val_acces[i]
140 |                 state_dict = dict(
141 |                     epoch=epoch + 1, model=model_list[i].state_dict(), acc=val_acces[i]
142 |                 )
143 |                 name = os.path.join(
144 |                     exp_path, args.model_names[i], "ckpt", "best{}.pth".format(i)
145 |                 )
146 |                 os.makedirs(os.path.dirname(name), exist_ok=True)
147 |                 torch.save(state_dict, name)
148 | 
149 |             scheduler_list[i].step()
150 | 
151 |         if (epoch + 1) % args.print_freq == 0:
152 |             for j in range(len(best_acc)):
153 |                 msg = "epoch:{} model:{} train loss:{:.2f} acc:{:.2f} val loss:{:.2f} acc:{:.2f}\n".format(
154 |                     epoch,
155 |                     args.model_names[j],
156 |                     train_losses[j],
157 |                     train_acces[j],
158 |                     val_losses[j],
159 |                     val_acces[j],
160 |                 )
161 |                 print(msg)
162 |                 f_writers[j].write(msg)
163 |                 f_writers[j].flush()
164 | 
165 |     for k in range(len(best_acc)):
166 |         msg_best = "model:{} best acc:{:.2f}".format(args.model_names[k], best_acc[k])
167 |         print(msg_best)
168 |         f_writers[k].write(msg_best)
169 |         f_writers[k].close()
170 | 
171 | 
172 | if __name__ == "__main__":
173 |     train_loader, val_loader = get_dataloader(args)
174 |     model_list = []
175 |     optimizer_list = []
176 |     scheduler_list = []
177 |     for name in args.model_names:
178 |         lr = args.lr
179 |         print(name)
180 |         model = model_dict[name](num_classes=args.classes_num)
181 |         if torch.cuda.is_available():
182 |             model = model.cuda()
183 | 
184 |         optimizer = optim.SGD(
185 |             model.parameters(),
186 |             lr=lr,
187 |             momentum=args.momentum,
188 |             nesterov=True,
189 |             weight_decay=args.weight_decay,
190 |         )
191 |         scheduler = MultiStepLR(optimizer, args.milestones, args.gamma)
192 |         model_list.append(model)
193 |         optimizer_list.append(optimizer)
194 |         scheduler_list.append(scheduler)
195 | 
196 |     train(model_list, optimizer_list, train_loader, scheduler_list)
--------------------------------------------------------------------------------
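A note on the `imgs[:, model_idx, ...]` indexing in `base.py` above: the dataset wrappers in `my_dataloader.py` apply `args.aug_nums` independent augmentations to each image and stack them, so a training batch has shape `[B, aug_nums, C, H, W]`. A minimal illustration (shapes assume CIFAR and the default `aug_nums=2`):

```python
import torch

# A training batch as produced by my_dataloader.py: B=64 images,
# each with aug_nums=2 independently augmented views.
imgs = torch.randn(64, 2, 3, 32, 32)
view0 = imgs[:, 0, ...]  # [64, 3, 32, 32]; the current batch in dlb.py
view1 = imgs[:, 1, ...]  # [64, 3, 32, 32]; dlb.py re-forwards the previous batch through this view
```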
/dlb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.optim as optim
4 | from torch.optim.lr_scheduler import MultiStepLR
5 | import torch.nn.functional as F
6 | from my_dataloader import get_dataloader
7 | from models import model_dict
8 | import os
9 | from utils import AverageMeter, accuracy
10 | import numpy as np
11 | from datetime import datetime
12 | import random
13 | 
14 | torch.backends.cudnn.benchmark = True  # NOTE: non-deterministic; set False for exact reproducibility
15 | torch.backends.cudnn.deterministic = True
16 | 
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--model_names", type=str, default="vgg16")  # a single model name, despite the plural
19 | 
20 | 
21 | parser.add_argument("--root", type=str, default="./dataset")
22 | parser.add_argument("--num_workers", type=int, default=16)
23 | parser.add_argument("--classes_num", type=int, default=100)
24 | parser.add_argument(
25 |     "--dataset",
26 |     type=str,
27 |     default="cifar100",
28 |     choices=["cifar100", "cifar10", "CUB", "tinyimagenet"],
29 |     help="dataset",
30 | )
31 | 
32 | parser.add_argument("--T", type=float, required=True)  # no defaults for these six flags;
33 | parser.add_argument("--alpha", type=float, required=True)  # they must be supplied on the command line
34 | parser.add_argument("--batch_size", type=int, required=True)
35 | parser.add_argument("--epoch", type=int, required=True)
36 | parser.add_argument("--lr", type=float, required=True)
37 | parser.add_argument("--milestones", type=int, nargs="+", required=True)
38 | 
39 | parser.add_argument("--momentum", type=float, default=0.9)
40 | parser.add_argument("--weight-decay", type=float, default=5e-4)
41 | parser.add_argument("--gamma", type=float, default=0.1)
42 | 
43 | parser.add_argument("--seed", type=int, default=95)
44 | parser.add_argument("--gpu-id", type=int, default=0)
45 | parser.add_argument("--print_freq", type=int, default=100)
46 | parser.add_argument("--aug_nums", type=int, default=2)
47 | parser.add_argument("--exp_postfix", type=str, default="TP3_0.5")
48 | 
49 | args = parser.parse_args()
50 | # NOTE: model_names is a plain string here (cf. base.py, where it is a list of names)
51 | 
52 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)  # must be set before anything initializes CUDA
53 | torch.manual_seed(args.seed)
54 | np.random.seed(args.seed)
55 | torch.cuda.manual_seed(args.seed)
56 | random.seed(args.seed)
57 | 
58 | exp_name = args.model_names + args.exp_postfix  # "_".join() would split the name into characters
59 | exp_path = "./dlb/{}/{}".format(args.dataset, exp_name)
60 | os.makedirs(exp_path, exist_ok=True)
61 | print(exp_path)
62 | 
63 | 
64 | def train_one_epoch(model, optimizer, train_loader, alpha, pre_data, pre_out):
65 |     model.train()
66 |     acc_recorder = AverageMeter()
67 |     loss_recorder = AverageMeter()
68 | 
69 |     for i, data in enumerate(train_loader):
70 | 
71 |         imgs, label = data
72 |         if torch.cuda.is_available():
73 |             imgs = imgs.cuda()
74 |             label = label.cuda()
75 |         out = model(imgs[:, 0, ...])  # view 0 of the current batch
76 | 
77 |         if pre_data is not None:
78 |             pre_images, pre_label = pre_data
79 |             if torch.cuda.is_available():
80 |                 pre_images = pre_images.cuda()
81 |                 pre_label = pre_label.cuda()
82 |             out_pre = model(pre_images[:, 1, ...])  # re-forward the previous batch through view 1
83 |             ce_loss = F.cross_entropy(
84 |                 torch.cat((out_pre, out), dim=0), torch.cat((pre_label, label), dim=0)
85 |             )
86 |             dml_loss = (
87 |                 F.kl_div(
88 |                     F.log_softmax(out_pre / args.T, dim=1),
89 |                     F.softmax(pre_out.detach() / args.T, dim=1),  # last step's logits act as the teacher
90 |                     reduction="batchmean",
91 |                 )
92 |                 * args.T
93 |                 * args.T
94 |             )
95 |             loss = ce_loss + alpha * dml_loss
96 |         else:
97 |             loss = F.cross_entropy(out, label)
98 | 
99 |         loss_recorder.update(loss.item(), n=imgs.size(0))
100 |         acc = accuracy(out, label)[0]
101 |         acc_recorder.update(acc.item(), n=imgs.size(0))
102 | 
103 |         pre_data = data
104 |         pre_out = out.detach()  # detach so the previous step's graph is not kept alive
105 | 
106 |         optimizer.zero_grad()
107 |         loss.backward()
108 |         optimizer.step()
109 | 
110 |     losses = loss_recorder.avg
111 |     acces = acc_recorder.avg
112 | 
113 |     return losses, acces, pre_data, pre_out
114 | 
115 | 
116 | def evaluation(model, val_loader):
117 |     model.eval()
118 |     acc_recorder = AverageMeter()
119 |     loss_recorder = AverageMeter()
120 | 
121 |     with torch.no_grad():
122 |         for img, label in val_loader:
123 |             if torch.cuda.is_available():
124 |                 img = img.cuda()
125 |                 label = label.cuda()
126 | 
127 |             out = model(img)
128 |             acc = accuracy(out, label)[0]
129 |             loss = F.cross_entropy(out, label)
130 |             acc_recorder.update(acc.item(), img.size(0))
131 |             loss_recorder.update(loss.item(), img.size(0))
132 |     losses = loss_recorder.avg
133 |     acces = acc_recorder.avg
134 |     return losses, acces
135 | 
136 | 
137 | def train(model, optimizer, train_loader, scheduler):
138 |     best_acc = -1
139 | 
140 |     f = open(os.path.join(exp_path, "log_test.txt"), "w")
141 | 
142 |     pre_data, pre_out = None, None
143 | 
144 |     for epoch in range(args.epoch):
145 |         alpha = args.alpha
146 |         train_losses, train_acces, pre_data, pre_out = train_one_epoch(
147 |             model, optimizer, train_loader, alpha, pre_data, pre_out
148 |         )
149 |         val_losses, val_acces = evaluation(model, val_loader)
150 | 
151 |         if val_acces > best_acc:
152 |             best_acc = val_acces
153 |             state_dict = dict(epoch=epoch + 1, model=model.state_dict(), acc=val_acces)
154 |             name = os.path.join(exp_path, args.model_names, "ckpt", "best.pth")
155 |             os.makedirs(os.path.dirname(name), exist_ok=True)
156 |             torch.save(state_dict, name)
157 | 
158 |         scheduler.step()
159 | 
160 |         if (epoch + 1) % args.print_freq == 0:
161 |             msg = "epoch:{} model:{} train loss:{:.2f} acc:{:.2f} val loss:{:.2f} acc:{:.2f}\n".format(
162 |                 epoch,
163 |                 args.model_names,
164 |                 train_losses,
165 |                 train_acces,
166 |                 val_losses,
167 |                 val_acces,
168 |             )
169 |             print(msg)
170 |             f.write(msg)
171 |             f.flush()
172 | 
173 |     msg_best = "model:{} best acc:{:.2f}".format(args.model_names, best_acc)
174 |     print(msg_best)
175 |     f.write(msg_best)
176 |     f.close()
177 | 
178 | 
179 | if __name__ == "__main__":
180 |     train_loader, val_loader = get_dataloader(args)
181 |     lr = args.lr
182 |     model = model_dict[args.model_names](num_classes=args.classes_num)
183 |     if torch.cuda.is_available():
184 |         model = model.cuda()
185 | 
186 |     optimizer = optim.SGD(
187 |         model.parameters(),
188 |         lr=lr,
189 |         momentum=args.momentum,
190 |         nesterov=True,
191 |         weight_decay=args.weight_decay,
192 |     )
193 |     scheduler = MultiStepLR(optimizer, args.milestones, args.gamma)
194 | 
195 |     train(model, optimizer, train_loader, scheduler)
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import vgg16, vgg19
2 | from .resnet import resnet32, resnet110, wide_resnet20_8
3 | from .densenet import densenetd40k12, densenetd100k12, densenetd100k40, densenetd190k12
4 | 
5 | model_dict = {
6 |     "vgg16": vgg16,
7 |     "vgg19": vgg19,
8 |     "resnet32": resnet32,
9 |     "resnet110": resnet110,
10 |     "wide_resnet20_8": wide_resnet20_8,
11 |     "densenetd40k12": densenetd40k12,
12 |     "densenetd100k12": densenetd100k12,
13 |     "densenetd100k40": densenetd100k40,
14 |     "densenetd190k12": densenetd190k12,
15 | }
16 | 
--------------------------------------------------------------------------------
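`model_dict` is how both training scripts resolve `--model_names` into constructors. An illustrative lookup (the key below is one of the real entries above):

```python
from models import model_dict

# Build a CIFAR-100 backbone by its CLI name, as base.py and dlb.py do.
net = model_dict["resnet32"](num_classes=100)
print(sum(p.numel() for p in net.parameters()))  # rough parameter count
```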
/models/densenet.py:
--------------------------------------------------------------------------------
1 | """
2 | DenseNet for CIFAR-10/100 Dataset.
3 | 
4 | Reference:
5 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
6 | 2. https://github.com/liuzhuang13/DenseNet
7 | 3. https://github.com/gpleiss/efficient_densenet_pytorch
8 | 4. Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. Weinberger
9 |    Densely Connected Convolutional Networks. https://arxiv.org/abs/1608.06993
10 | 
11 | """
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | import torch.utils.checkpoint as cp
16 | from collections import OrderedDict
17 | 
18 | __all__ = [
19 |     "DenseNet",
20 |     "densenetd40k12",
21 |     "densenetd100k12",
22 |     "densenetd100k40",
23 |     "densenetd190k12",
24 | ]
25 | 
26 | 
27 | def _bn_function_factory(norm, relu, conv):
28 |     def bn_function(*inputs):
29 |         concated_features = torch.cat(inputs, 1)
30 |         bottleneck_output = conv(relu(norm(concated_features)))
31 |         return bottleneck_output
32 | 
33 |     return bn_function
34 | 
35 | 
36 | class _DenseLayer(nn.Module):
37 |     def __init__(
38 |         self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False
39 |     ):
40 |         super(_DenseLayer, self).__init__()
41 |         self.add_module("norm1", nn.BatchNorm2d(num_input_features))
42 |         self.add_module("relu1", nn.ReLU(inplace=True))
43 |         self.add_module(
44 |             "conv1",
45 |             nn.Conv2d(
46 |                 num_input_features,
47 |                 bn_size * growth_rate,
48 |                 kernel_size=1,
49 |                 stride=1,
50 |                 bias=False,
51 |             ),
52 |         )
53 |         self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
54 |         self.add_module("relu2", nn.ReLU(inplace=True))
55 |         self.add_module(
56 |             "conv2",
57 |             nn.Conv2d(
58 |                 bn_size * growth_rate,
59 |                 growth_rate,
60 |                 kernel_size=3,
61 |                 stride=1,
62 |                 padding=1,
63 |                 bias=False,
64 |             ),
65 |         )
66 |         self.drop_rate = drop_rate
67 |         self.efficient = efficient
68 | 
69 |     def forward(self, *prev_features):
70 |         bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
71 |         if self.efficient and any(
72 |             prev_feature.requires_grad for prev_feature in prev_features
73 |         ):
74 |             bottleneck_output = cp.checkpoint(bn_function, *prev_features)
75 |         else:
76 |             bottleneck_output = bn_function(*prev_features)
77 |         new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
78 |         if self.drop_rate > 0:
79 |             new_features = F.dropout(
80 |                 new_features, p=self.drop_rate, training=self.training
81 |             )
82 |         return new_features
83 | 
84 | 
85 | class _Transition(nn.Sequential):
86 |     def __init__(self, num_input_features, num_output_features):
87 |         super(_Transition, self).__init__()
88 |         self.add_module("norm", nn.BatchNorm2d(num_input_features))
89 |         self.add_module("relu", nn.ReLU(inplace=True))
90 |         self.add_module(
91 |             "conv",
92 |             nn.Conv2d(
93 |                 num_input_features,
94 |                 num_output_features,
95 |                 kernel_size=1,
96 |                 stride=1,
97 |                 bias=False,
98 |             ),
99 |         )
100 |         self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))
101 | 
102 | 
103 | class _DenseBlock(nn.Module):
104 |     def __init__(
105 |         self,
106 |         num_layers,
107 |         num_input_features,
108 |         bn_size,
109 |         growth_rate,
110 |         drop_rate,
111 |         efficient=False,
112 |     ):
113 |         super(_DenseBlock, self).__init__()
114 |         for i in
range(num_layers): 115 | layer = _DenseLayer( 116 | num_input_features + i * growth_rate, 117 | growth_rate=growth_rate, 118 | bn_size=bn_size, 119 | drop_rate=drop_rate, 120 | efficient=efficient, 121 | ) 122 | self.add_module("denselayer%d" % (i + 1), layer) 123 | 124 | def forward(self, init_features): 125 | features = [init_features] 126 | for name, layer in self.named_children(): 127 | new_features = layer(*features) 128 | features.append(new_features) 129 | return torch.cat(features, 1) 130 | 131 | 132 | class DenseNet(nn.Module): 133 | r"""Densenet-BC model class, based on 134 | `"Densely Connected Convolutional Networks" ` 135 | Args: 136 | growth_rate (int) - how many filters to add each layer (`k` in paper) 137 | block_config (list of 3 or 4 ints) - how many layers in each pooling block 138 | num_init_features (int) - the number of filters to learn in the first convolution layer 139 | bn_size (int) - multiplicative factor for number of bottle neck layers 140 | (i.e. bn_size * k features in the bottleneck layer) 141 | drop_rate (float) - dropout rate after each dense layer 142 | num_classes (int) - number of classification classes 143 | small_inputs (bool) - set to True if images are 32x32. Otherwise assumes images are larger. 144 | efficient (bool) - set to True to use checkpointing. Much more memory efficient, but slower. 145 | """ 146 | 147 | def __init__( 148 | self, 149 | growth_rate=12, 150 | block_config=[16, 16, 16], 151 | compression=0.5, 152 | num_init_features=24, 153 | bn_size=4, 154 | drop_rate=0, 155 | num_classes=10, 156 | small_inputs=True, 157 | efficient=False, 158 | KD=False, 159 | ): 160 | 161 | super(DenseNet, self).__init__() 162 | assert 0 < compression <= 1, "compression of densenet should be between 0 and 1" 163 | self.avgpool_size = 8 if small_inputs else 7 164 | self.KD = KD 165 | # First convolution 166 | if small_inputs: 167 | self.features = nn.Sequential( 168 | OrderedDict( 169 | [ 170 | ( 171 | "conv0", 172 | nn.Conv2d( 173 | 3, 174 | num_init_features, 175 | kernel_size=3, 176 | stride=1, 177 | padding=1, 178 | bias=False, 179 | ), 180 | ), 181 | ] 182 | ) 183 | ) 184 | else: 185 | self.features = nn.Sequential( 186 | OrderedDict( 187 | [ 188 | ( 189 | "conv0", 190 | nn.Conv2d( 191 | 3, 192 | num_init_features, 193 | kernel_size=7, 194 | stride=2, 195 | padding=3, 196 | bias=False, 197 | ), 198 | ), 199 | ] 200 | ) 201 | ) 202 | self.features.add_module("norm0", nn.BatchNorm2d(num_init_features)) 203 | self.features.add_module("relu0", nn.ReLU(inplace=True)) 204 | self.features.add_module( 205 | "pool0", 206 | nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False), 207 | ) 208 | 209 | # Each denseblock 210 | num_features = num_init_features 211 | for i, num_layers in enumerate(block_config): 212 | block = _DenseBlock( 213 | num_layers=num_layers, 214 | num_input_features=num_features, 215 | bn_size=bn_size, 216 | growth_rate=growth_rate, 217 | drop_rate=drop_rate, 218 | efficient=efficient, 219 | ) 220 | self.features.add_module("denseblock%d" % (i + 1), block) 221 | num_features = num_features + num_layers * growth_rate 222 | if i != len(block_config) - 1: 223 | trans = _Transition( 224 | num_input_features=num_features, 225 | num_output_features=int(num_features * compression), 226 | ) 227 | self.features.add_module("transition%d" % (i + 1), trans) 228 | num_features = int(num_features * compression) 229 | 230 | # Final batch norm 231 | self.features.add_module("norm_final", nn.BatchNorm2d(num_features)) 232 | 233 | # Linear layer 234 
| self.classifier = nn.Linear(num_features, num_classes) 235 | # Initialization 236 | for m in self.modules(): 237 | if isinstance(m, nn.Conv2d): 238 | nn.init.kaiming_normal_(m.weight) 239 | elif isinstance(m, nn.BatchNorm2d): 240 | nn.init.constant_(m.weight, 1) 241 | nn.init.constant_(m.bias, 0) 242 | elif isinstance(m, nn.Linear): 243 | nn.init.constant_(m.bias, 0) 244 | 245 | def forward(self, x): 246 | features = self.features(x) 247 | # D40K12 B x 132 x 8 x 8 248 | # D100K12 B x 342 x 8 x 8 249 | # D100K40 B x 1126 x 8 x 8 250 | x = F.relu(features, inplace=True) 251 | x_f = F.avg_pool2d(x, kernel_size=self.avgpool_size).view( 252 | features.size(0), -1 253 | ) # B x 132 254 | x = self.classifier(x_f) 255 | if self.KD == True: 256 | return x_f, x 257 | else: 258 | return x 259 | 260 | 261 | def densenetd40k12(pretrained=False, path=None, **kwargs): 262 | """ 263 | Constructs a densenetD40K12 model. 264 | 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained. 267 | """ 268 | 269 | model = DenseNet(growth_rate=12, block_config=[6, 6, 6], **kwargs) 270 | if pretrained: 271 | model.load_state_dict((torch.load(path))["state_dict"]) 272 | return model 273 | 274 | 275 | def densenetd100k12(pretrained=False, path=None, **kwargs): 276 | """ 277 | Constructs a densenetD100K12 model. 278 | 279 | Args: 280 | pretrained (bool): If True, returns a model pre-trained. 281 | """ 282 | 283 | model = DenseNet(growth_rate=12, block_config=[16, 16, 16], **kwargs) 284 | if pretrained: 285 | model.load_state_dict((torch.load(path))["state_dict"]) 286 | return model 287 | 288 | 289 | def densenetd190k12(pretrained=False, path=None, **kwargs): 290 | """ 291 | Constructs a densenetD190K12 model. 292 | 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained. 295 | """ 296 | 297 | model = DenseNet(growth_rate=12, block_config=[31, 31, 31], **kwargs) 298 | if pretrained: 299 | model.load_state_dict((torch.load(path))["state_dict"]) 300 | return model 301 | 302 | 303 | def densenetd100k40(pretrained=False, path=None, **kwargs): 304 | """ 305 | Constructs a densenetD100K40 model. 306 | 307 | Args: 308 | pretrained (bool): If True, returns a model pre-trained on ImageNet 309 | """ 310 | 311 | model = DenseNet(growth_rate=40, block_config=[16, 16, 16], **kwargs) 312 | if pretrained: 313 | model.load_state_dict((torch.load(path))["state_dict"]) 314 | return model 315 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet for CIFAR-10/100 Dataset. 3 | 4 | Reference: 5 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | 2. https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua 7 | 3. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 8 | Deep Residual Learning for Image Recognition. 
https://arxiv.org/abs/1512.03385 9 | 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | __all__ = ["ResNet", "resnet32", "resnet110", "wide_resnet20_8"] 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d( 21 | in_planes, 22 | out_planes, 23 | kernel_size=3, 24 | stride=stride, 25 | padding=dilation, 26 | groups=groups, 27 | bias=False, 28 | dilation=dilation, 29 | ) 30 | 31 | 32 | def conv1x1(in_planes, out_planes, stride=1): 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__( 41 | self, 42 | inplanes, 43 | planes, 44 | stride=1, 45 | downsample=None, 46 | groups=1, 47 | base_width=64, 48 | dilation=1, 49 | norm_layer=None, 50 | ): 51 | super(BasicBlock, self).__init__() 52 | if norm_layer is None: 53 | norm_layer = nn.BatchNorm2d 54 | if groups != 1 or base_width != 64: 55 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 56 | if dilation > 1: 57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = norm_layer(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = norm_layer(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | expansion = 4 88 | 89 | def __init__( 90 | self, 91 | inplanes, 92 | planes, 93 | stride=1, 94 | downsample=None, 95 | groups=1, 96 | base_width=64, 97 | dilation=1, 98 | norm_layer=None, 99 | ): 100 | super(Bottleneck, self).__init__() 101 | if norm_layer is None: 102 | norm_layer = nn.BatchNorm2d 103 | width = int(planes * (base_width / 64.0)) * groups 104 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 105 | self.conv1 = conv1x1(inplanes, width) 106 | self.bn1 = norm_layer(width) 107 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 108 | self.bn2 = norm_layer(width) 109 | self.conv3 = conv1x1(width, planes * self.expansion) 110 | self.bn3 = norm_layer(planes * self.expansion) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.downsample = downsample 113 | self.stride = stride 114 | 115 | def forward(self, x): 116 | identity = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | identity = self.downsample(x) 131 | 132 | out += identity 133 | out = self.relu(out) 134 | 135 | return out 136 | 137 | 138 | class ResNet(nn.Module): 139 | def __init__( 140 | self, 141 | block, 142 | layers, 143 | num_classes=10, 144 | zero_init_residual=False, 145 | groups=1, 146 | width_per_group=64, 147 | replace_stride_with_dilation=None, 148 | norm_layer=None, 149 | KD=False, 150 | ): 151 | super(ResNet, 
self).__init__() 152 | if norm_layer is None: 153 | norm_layer = nn.BatchNorm2d 154 | self._norm_layer = norm_layer 155 | 156 | self.inplanes = 16 157 | self.dilation = 1 158 | if replace_stride_with_dilation is None: 159 | # each element in the tuple indicates if we should replace 160 | # the 2x2 stride with a dilated convolution instead 161 | replace_stride_with_dilation = [False, False, False] 162 | if len(replace_stride_with_dilation) != 3: 163 | raise ValueError( 164 | "replace_stride_with_dilation should be None " 165 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 166 | ) 167 | 168 | self.groups = groups 169 | self.base_width = width_per_group 170 | self.conv1 = nn.Conv2d( 171 | 3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False 172 | ) 173 | self.bn1 = nn.BatchNorm2d(self.inplanes) 174 | self.relu = nn.ReLU(inplace=True) 175 | # self.maxpool = nn.MaxPool2d() 176 | self.layer1 = self._make_layer(block, 16, layers[0]) 177 | self.layer2 = self._make_layer(block, 32, layers[1], stride=2) 178 | self.layer3 = self._make_layer(block, 64, layers[2], stride=2) 179 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 180 | self.fc = nn.Linear(64 * block.expansion, num_classes) 181 | self.KD = KD 182 | for m in self.modules(): 183 | if isinstance(m, nn.Conv2d): 184 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 185 | elif isinstance(m, nn.BatchNorm2d): 186 | nn.init.constant_(m.weight, 1) 187 | nn.init.constant_(m.bias, 0) 188 | # Zero-initialize the last BN in each residual branch, 189 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 190 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 191 | if zero_init_residual: 192 | for m in self.modules(): 193 | if isinstance(m, Bottleneck): 194 | nn.init.constant_(m.bn3.weight, 0) 195 | elif isinstance(m, BasicBlock): 196 | nn.init.constant_(m.bn2.weight, 0) 197 | 198 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 199 | norm_layer = self._norm_layer 200 | downsample = None 201 | previous_dilation = self.dilation 202 | if dilate: 203 | self.dilation *= stride 204 | stride = 1 205 | if stride != 1 or self.inplanes != planes * block.expansion: 206 | downsample = nn.Sequential( 207 | conv1x1(self.inplanes, planes * block.expansion, stride), 208 | norm_layer(planes * block.expansion), 209 | ) 210 | 211 | layers = [] 212 | layers.append( 213 | block( 214 | self.inplanes, 215 | planes, 216 | stride, 217 | downsample, 218 | self.groups, 219 | self.base_width, 220 | previous_dilation, 221 | norm_layer, 222 | ) 223 | ) 224 | self.inplanes = planes * block.expansion 225 | for _ in range(1, blocks): 226 | layers.append( 227 | block( 228 | self.inplanes, 229 | planes, 230 | groups=self.groups, 231 | base_width=self.base_width, 232 | dilation=self.dilation, 233 | norm_layer=norm_layer, 234 | ) 235 | ) 236 | 237 | return nn.Sequential(*layers) 238 | 239 | def forward(self, x): 240 | x = self.conv1(x) 241 | x = self.bn1(x) 242 | x = self.relu(x) # B x 16 x 32 x 32 243 | 244 | x = self.layer1(x) # B x 16 x 32 x 32 245 | x = self.layer2(x) # B x 32 x 16 x 16 246 | x = self.layer3(x) # B x 64 x 8 x 8 247 | 248 | x = self.avgpool(x) # B x 64 x 1 x 1 249 | x_f = x.view(x.size(0), -1) # B x 64 250 | x = self.fc(x_f) # B x num_classes 251 | if self.KD == True: 252 | return x_f, x 253 | else: 254 | return x 255 | 256 | 257 | def resnet32(pretrained=False, path=None, **kwargs): 258 | """ 259 | Constructs a 
ResNet-32 model.
260 | 
261 |     Args:
262 |         pretrained (bool): If True, returns a model pre-trained.
263 |     """
264 | 
265 |     model = ResNet(BasicBlock, [5, 5, 5], **kwargs)
266 |     if pretrained:
267 |         model.load_state_dict((torch.load(path))["state_dict"])
268 |     return model
269 | 
270 | 
271 | def resnet110(pretrained=False, path=None, **kwargs):
272 |     """
273 |     Constructs a ResNet-110 model.
274 | 
275 |     Args:
276 |         pretrained (bool): If True, returns a model pre-trained.
277 |     """
278 | 
279 |     model = ResNet(Bottleneck, [12, 12, 12], **kwargs)
280 |     if pretrained:
281 |         model.load_state_dict((torch.load(path))["state_dict"])
282 |     return model
283 | 
284 | 
285 | def wide_resnet20_8(pretrained=False, path=None, **kwargs):
286 | 
287 |     """Constructs a Wide ResNet-20-8 model.
288 |     The model is the same as the CIFAR ResNet above, except that every
289 |     bottleneck block is 8x wider (width_per_group=64 * 8), in the spirit of
290 |     the Wide ResNet family; with Bottleneck blocks and a [2, 2, 2] stage
291 |     configuration the depth works out to 20 layers.
292 |     Args:
293 |         pretrained (bool): If True, returns a model pre-trained.
294 |     """
295 | 
296 |     model = ResNet(Bottleneck, [2, 2, 2], width_per_group=64 * 8, **kwargs)
297 |     if pretrained:
298 |         model.load_state_dict((torch.load(path))["state_dict"])
299 |     return model
300 | 
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG16 for CIFAR-10/100 Dataset.
3 | 
4 | Reference:
5 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
6 | 
7 | """
8 | 
9 | import torch
10 | import torch.nn as nn
11 | 
12 | __all__ = ["vgg16", "vgg19"]
13 | 
14 | # cfg = {
15 | #     16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
16 | #     19: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
17 | # }
18 | 
19 | 
20 | class VGG(nn.Module):
21 |     def __init__(self, num_classes=10, depth=16, dropout=0.0, KD=False):
22 |         super(VGG, self).__init__()
23 |         self.KD = KD
24 |         self.inplances = 64
25 |         self.conv1 = nn.Conv2d(3, self.inplances, kernel_size=3, padding=1)
26 |         self.bn1 = nn.BatchNorm2d(self.inplances)
27 |         self.conv2 = nn.Conv2d(self.inplances, self.inplances, kernel_size=3, padding=1)
28 |         self.bn2 = nn.BatchNorm2d(self.inplances)
29 |         self.relu = nn.ReLU(True)
30 |         self.layer1 = self._make_layers(128, 2)
31 |         self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
32 | 
33 |         if depth == 16:
34 |             num_layer = 3
35 |         elif depth == 19:
36 |             num_layer = 4
37 | 
38 |         self.layer2 = self._make_layers(256, num_layer)
39 |         self.layer3 = self._make_layers(512, num_layer)
40 |         self.layer4 = self._make_layers(512, num_layer)
41 | 
42 |         self.classifier = nn.Sequential(
43 |             nn.Linear(512, 512),
44 |             nn.ReLU(True),
45 |             nn.Dropout(p=dropout),
46 |             nn.Linear(512, 512),
47 |             nn.ReLU(True),
48 |             nn.Dropout(p=dropout),
49 |             nn.Linear(512, num_classes),
50 |         )
51 |         for m in self.modules():
52 |             if isinstance(m, nn.Conv2d):
53 |                 nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
54 |                 if m.bias is not None:
55 |                     nn.init.constant_(m.bias, 0)
56 |             elif isinstance(m, nn.BatchNorm2d):
57 |                 nn.init.constant_(m.weight, 1)
58 |                 nn.init.constant_(m.bias, 0)
59 |             elif isinstance(m, nn.Linear):
60 |                 nn.init.normal_(m.weight, 0, 0.01)
61 |                 nn.init.constant_(m.bias, 0)
62 | 
63 |     def _make_layers(self, input, num_layer):
64 |         layers = []
65 |         for i in range(num_layer):
66 |             conv2d = nn.Conv2d(self.inplances, input, kernel_size=3, padding=1)
67 |             layers += [conv2d, nn.BatchNorm2d(input), nn.ReLU(inplace=True)]
68 |             self.inplances = input
69 |         layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
70 |         return nn.Sequential(*layers)
71 | 
72 |     def forward(self, x):
73 | 
74 |         x = self.conv1(x)
75 |         x = self.bn1(x)
76 |         x = self.relu(x)
77 | 
78 |         x = self.conv2(x)
79 |         x = self.bn2(x)
80 |         x = self.relu(x)
81 |         x = self.maxpool(x)
82 | 
83 |         x = self.layer1(x)
84 |         x = self.layer2(x)
85 |         x = self.layer3(x)
86 |         xs = self.layer4(x)
87 | 
88 |         x_f = xs.view(xs.size(0), -1)
89 |         x = self.classifier(x_f)
90 |         if self.KD:
91 |             return x_f, x
92 |         else:
93 |             return x
94 | 
95 | 
96 | def vgg16(pretrained=False, path=None, **kwargs):
97 |     """
98 |     Constructs a VGG16 model.
99 | 
100 |     Args:
101 |         pretrained (bool): If True, returns a model pre-trained.
102 |     """
103 |     model = VGG(depth=16, **kwargs)
104 |     if pretrained:
105 |         model.load_state_dict((torch.load(path))["state_dict"])
106 |     return model
107 | 
108 | 
109 | def vgg19(pretrained=False, path=None, **kwargs):
110 |     """
111 |     Constructs a VGG19 model.
112 | 
113 |     Args:
114 |         pretrained (bool): If True, returns a model pre-trained.
115 |     """
116 |     model = VGG(depth=19, **kwargs)
117 |     if pretrained:
118 |         model.load_state_dict((torch.load(path))["state_dict"])
119 |     return model
120 | 
--------------------------------------------------------------------------------
/my_dataloader.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets.utils import check_integrity, download_url  # used by the download helpers below
2 | from torch.utils.data import DataLoader
3 | from torchvision import transforms
4 | from torchvision.datasets import ImageFolder
5 | from PIL import Image
6 | import os
7 | import os.path
8 | import numpy as np
9 | import sys
10 | 
11 | import pickle
12 | import torch
13 | import torch.utils.data as data
14 | 
15 | from torchvision.datasets.vision import StandardTransform
16 | 
17 | 
18 | def set_seed(seed):
19 |     torch.manual_seed(seed)
20 |     np.random.seed(seed)
21 | 
22 | 
23 | class VisionDataset(data.Dataset):
24 |     _repr_indent = 4
25 | 
26 |     def __init__(self, root, transforms=None, transform=None, target_transform=None):
27 |         if isinstance(root, str):  # torch._six.string_classes was removed from modern PyTorch
28 |             root = os.path.expanduser(root)
29 |         self.root = root
30 | 
31 |         has_transforms = transforms is not None
32 |         has_separate_transform = transform is not None or target_transform is not None
33 |         if has_transforms and has_separate_transform:
34 |             raise ValueError(
35 |                 "Only transforms or transform/target_transform can "
36 |                 "be passed as argument"
37 |             )
38 | 
39 |         self.transform = transform
40 |         self.target_transform = target_transform
41 | 
42 |         if has_separate_transform:
43 |             transforms = StandardTransform(transform, target_transform)
44 |         self.transforms = transforms
45 | 
46 |     def __getitem__(self, index):
47 |         raise NotImplementedError
48 | 
49 |     def __len__(self):
50 |         raise NotImplementedError
51 | 
52 |     def __repr__(self):
53 |         head = "Dataset " + self.__class__.__name__
54 |         body = ["Number of datapoints: {}".format(self.__len__())]
55 |         if self.root is not None:
56 |             body.append("Root location: {}".format(self.root))
57 |         body += self.extra_repr().splitlines()
58 |         if self.transforms is not None:
59 |             body += [repr(self.transforms)]
60 |         lines = [head] + [" " * self._repr_indent + line for line in body]
61 |         return "\n".join(lines)
62 | 
63 |     def _format_transform_repr(self, transform, head):
64 | 
lines = transform.__repr__().splitlines() 65 | return ["{}{}".format(head, lines[0])] + [ 66 | "{}{}".format(" " * len(head), line) for line in lines[1:] 67 | ] 68 | 69 | def extra_repr(self): 70 | return "" 71 | 72 | 73 | class CIFAR10(VisionDataset): 74 | base_folder = "cifar-10-batches-py" 75 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 76 | filename = "cifar-10-python.tar.gz" 77 | tgz_md5 = "c58f30108f718f92721af3b95e74349a" 78 | train_list = [ 79 | ["data_batch_1", "c99cafc152244af753f735de768cd75f"], 80 | ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"], 81 | ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"], 82 | ["data_batch_4", "634d18415352ddfa80567beed471001a"], 83 | ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"], 84 | ] 85 | 86 | test_list = [ 87 | ["test_batch", "40351d587109b95175f43aff81a1287e"], 88 | ] 89 | meta = { 90 | "filename": "batches.meta", 91 | "key": "label_names", 92 | "md5": "5ff9c542aee3614f3951f8cda6e48888", 93 | } 94 | 95 | def __init__( 96 | self, root, train=True, transform=None, download=False, transform_list=None 97 | ): 98 | 99 | super(CIFAR10, self).__init__(root) 100 | self.transform = transform 101 | self.transform_list = transform_list 102 | self.train = train # training set or test set 103 | 104 | if download: 105 | raise ValueError("cannot download.") 106 | exit() 107 | 108 | if self.train: 109 | downloaded_list = self.train_list 110 | else: 111 | downloaded_list = self.test_list 112 | 113 | self.data = [] 114 | self.targets = [] 115 | 116 | 117 | for file_name, checksum in downloaded_list: 118 | file_path = os.path.join(self.root, self.base_folder, file_name) 119 | with open(file_path, "rb") as f: 120 | if sys.version_info[0] == 2: 121 | entry = pickle.load(f) 122 | else: 123 | entry = pickle.load(f, encoding="latin1") 124 | self.data.append(entry["data"]) 125 | if "labels" in entry: 126 | self.targets.extend(entry["labels"]) 127 | else: 128 | self.targets.extend(entry["fine_labels"]) 129 | 130 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 131 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 132 | 133 | self._load_meta() 134 | 135 | def _load_meta(self): 136 | path = os.path.join(self.root, self.base_folder, self.meta["filename"]) 137 | # if not check_integrity(path, self.meta['md5']): 138 | # raise RuntimeError('Dataset metadata file not found or corrupted.' 
+ 139 | # ' You can use download=True to download it') 140 | with open(path, "rb") as infile: 141 | if sys.version_info[0] == 2: 142 | data = pickle.load(infile) 143 | else: 144 | data = pickle.load(infile, encoding="latin1") 145 | self.classes = data[self.meta["key"]] 146 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 147 | 148 | def __getitem__(self, index): 149 | 150 | img, target = self.data[index], self.targets[index] 151 | 152 | if self.transform_list is not None: 153 | img_transformed = [] 154 | for transform in self.transform_list: 155 | img_transformed.append(transform(Image.fromarray(img.copy()))) 156 | img = torch.stack(img_transformed) 157 | else: 158 | img = self.transform(Image.fromarray(img)) 159 | return img, target 160 | 161 | def __len__(self): 162 | return len(self.data) 163 | 164 | def _check_integrity(self): 165 | root = self.root 166 | for fentry in self.train_list + self.test_list: 167 | filename, md5 = fentry[0], fentry[1] 168 | fpath = os.path.join(root, self.base_folder, filename) 169 | if not check_integrity(fpath, md5): 170 | return False 171 | return True 172 | 173 | def download(self): 174 | import tarfile 175 | 176 | if self._check_integrity(): 177 | print("Files already downloaded and verified") 178 | return 179 | 180 | download_url(self.url, self.root, self.filename, self.tgz_md5) 181 | 182 | # extract file 183 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 184 | tar.extractall(path=self.root) 185 | 186 | def extra_repr(self): 187 | return "Split: {}".format("Train" if self.train is True else "Test") 188 | 189 | 190 | class CIFAR100(CIFAR10): 191 | """`CIFAR100 `_ Dataset. 192 | 193 | This is a subclass of the `CIFAR10` Dataset. 194 | """ 195 | 196 | base_folder = "cifar-100-python" 197 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 198 | filename = "cifar-100-python.tar.gz" 199 | tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85" 200 | train_list = [ 201 | ["train", "16019d7e3df5f24257cddd939b257f8d"], 202 | ] 203 | 204 | test_list = [ 205 | ["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"], 206 | ] 207 | meta = { 208 | "filename": "meta", 209 | "key": "fine_label_names", 210 | "md5": "7973b15100ade9c7d40fb424638fde48", 211 | } 212 | 213 | 214 | class TinyImageNet(ImageFolder): 215 | def __init__(self, root, transform=None, transform_list=None): 216 | 217 | super(TinyImageNet, self).__init__(root=root, transform=transform) 218 | self.transform_list = transform_list 219 | 220 | def __getitem__(self, index): 221 | 222 | path, target = self.imgs[index] 223 | img = self.loader(path) 224 | if self.transform_list is not None: 225 | img_transformed = [] 226 | for transform in self.transform_list: 227 | img_transformed.append(transform(img.copy())) 228 | img = torch.stack(img_transformed) 229 | else: 230 | img = self.transform(img) 231 | return img, target 232 | 233 | 234 | def get_dataloader(args, ddp=False): 235 | train_transforms = [] 236 | 237 | if args.dataset == "cifar100": 238 | normalize = transforms.Normalize( 239 | (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762) 240 | ) 241 | elif args.dataset == "cifar10": 242 | normalize = transforms.Normalize( 243 | (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616) 244 | ) 245 | elif args.dataset == "CUB" or args.dataset == "tinyimagenet": 246 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 247 | 248 | if args.dataset == "cifar100" or args.dataset == "cifar10": 249 | for i in range(args.aug_nums): 250 | 
            train_transform = transforms.Compose(
251 |                 [
252 |                     transforms.RandomCrop(32, padding=4),
253 |                     transforms.RandomHorizontalFlip(),
254 |                     transforms.ToTensor(),
255 |                     normalize,
256 |                 ]
257 |             )
258 |             train_transforms.append(train_transform)
259 | 
260 |         test_transform = transforms.Compose(
261 |             [
262 |                 transforms.ToTensor(),
263 |                 normalize,
264 |             ]
265 |         )
266 | 
267 |     elif args.dataset == "CUB":
268 |         for i in range(args.aug_nums):
269 |             train_transform = transforms.Compose(
270 |                 [
271 |                     transforms.RandomResizedCrop(224),
272 |                     transforms.RandomHorizontalFlip(),
273 |                     transforms.ToTensor(),
274 |                     normalize,
275 |                 ]
276 |             )
277 | 
278 |             train_transforms.append(train_transform)
279 |         test_transform = transforms.Compose(
280 |             [
281 |                 transforms.Resize(256),
282 |                 transforms.CenterCrop(224),
283 |                 transforms.ToTensor(),
284 |                 normalize,
285 |             ]
286 |         )
287 | 
288 |     elif args.dataset == "tinyimagenet":
289 |         for i in range(args.aug_nums):
290 |             train_transform = transforms.Compose(
291 |                 [
292 |                     transforms.RandomResizedCrop(32),
293 |                     transforms.RandomHorizontalFlip(),
294 |                     transforms.ToTensor(),
295 |                     normalize,
296 |                 ]
297 |             )
298 | 
299 |             train_transforms.append(train_transform)
300 |         test_transform = transforms.Compose(
301 |             [transforms.Resize(32), transforms.ToTensor(), normalize]
302 |         )
303 | 
304 |     if args.dataset == "cifar100":
305 |         trainset = CIFAR100(
306 |             root=args.root, train=True, transform_list=train_transforms, download=False
307 |         )
308 |         valset = CIFAR100(
309 |             root=args.root, train=False, transform=test_transform, download=False
310 |         )
311 |     elif args.dataset == "cifar10":
312 |         trainset = CIFAR10(
313 |             root=args.root, train=True, transform_list=train_transforms, download=False
314 |         )
315 |         valset = CIFAR10(
316 |             root=args.root, train=False, transform=test_transform, download=False
317 |         )
318 |     elif args.dataset == "tinyimagenet":
319 |         trainset = TinyImageNet(
320 |             root=os.path.join(args.root, "train"), transform_list=train_transforms
321 |         )
322 |         valset = TinyImageNet(
323 |             root=os.path.join(args.root, "val"), transform=test_transform
324 |         )
325 |     elif args.dataset == "CUB":
326 |         trainset = TinyImageNet(  # no dedicated CUB class is defined; CUB is assumed to use the same ImageFolder layout
327 |             root=os.path.join(args.root, "train"), transform_list=train_transforms
328 |         )
329 |         valset = TinyImageNet(root=os.path.join(args.root, "test"), transform=test_transform)
330 | 
331 |     if not ddp:
332 |         train_loader = DataLoader(
333 |             trainset,
334 |             batch_size=args.batch_size,
335 |             shuffle=True,
336 |             num_workers=args.num_workers,
337 |         )
338 |         val_loader = DataLoader(
339 |             valset,
340 |             batch_size=args.batch_size,
341 |             shuffle=False,
342 |             num_workers=args.num_workers,
343 |         )
344 |     else:
345 |         # DistributedSampler
346 |         train_sampler = torch.utils.data.distributed.DistributedSampler(
347 |             trainset,
348 |             shuffle=True,
349 |         )
350 |         val_sampler = torch.utils.data.distributed.DistributedSampler(
351 |             valset,
352 |             shuffle=False,
353 |         )
354 |         train_loader = torch.utils.data.DataLoader(
355 |             trainset,
356 |             batch_size=args.batch_size,
357 |             num_workers=args.num_workers,
358 |             pin_memory=True,
359 |             sampler=train_sampler,
360 |         )
361 |         val_loader = torch.utils.data.DataLoader(
362 |             valset,
363 |             batch_size=args.batch_size,
364 |             num_workers=args.num_workers,
365 |             pin_memory=True,
366 |             sampler=val_sampler,
367 |         )
368 | 
369 |     return train_loader, val_loader
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.nn import init
8 | 
9 | 
10 | class AverageMeter(object):
11 |     """Computes and stores the average and current value"""
12 | 
13 |     def __init__(self):
14 |         self.reset()
15 | 
16 |     def reset(self):
17 |         self.count = 0
18 |         self.sum = 0.0
19 |         self.val = 0.0
20 |         self.avg = 0.0
21 | 
22 |     def update(self, val, n=1):
23 |         self.val = val
24 |         self.sum += val * n
25 |         self.count += n
26 |         self.avg = self.sum / self.count
27 | 
28 | 
29 | def accuracy(output, target, topk=(1,)):
30 |     """Computes the precision@k for the specified values of k"""
31 |     maxk = max(topk)
32 |     batch_size = target.size(0)
33 | 
34 |     _, pred = output.topk(maxk, 1, True, True)
35 |     pred = pred.t()
36 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
37 | 
38 |     res = []
39 |     for k in topk:
40 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: slices of a transposed tensor are non-contiguous
41 |         res.append(correct_k.mul_(100.0 / batch_size))
42 |     return res
43 | 
--------------------------------------------------------------------------------
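For reference, an illustrative use of the helpers in `utils.py` (random tensors stand in for real model outputs):

```python
import torch
from utils import AverageMeter, accuracy

meter = AverageMeter()
logits = torch.randn(8, 100)              # fake batch of class scores
labels = torch.randint(0, 100, (8,))
top1, top5 = accuracy(logits, labels, topk=(1, 5))  # values are percentages
meter.update(top1.item(), n=labels.size(0))
print("running top-1: {:.2f}".format(meter.avg))
```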