├── helper ├── __init__.py ├── util.py └── loops.py ├── distiller_zoo ├── __init__.py ├── kl_div.py ├── mse_norm.py └── COS.py ├── train_scripts.sh ├── evaluate_scripts.sh ├── models ├── classifier.py ├── __init__.py ├── ShuffleNetv1.py ├── mobilenetv2.py ├── wrn.py ├── resnetv2.py ├── ShuffleNetv2.py ├── vgg.py ├── resnet.py └── util.py ├── README.md ├── dataset └── cifar100.py ├── evaluate_student.py └── train_student.py /helper/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /distiller_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .kl_div import KL 2 | from .COS import Cosine 3 | from .mse_norm import NORM_MSE -------------------------------------------------------------------------------- /distiller_zoo/kl_div.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class KL(nn.Module): 8 | """KL_div""" 9 | def __init__(self, T): 10 | super(KL, self).__init__() 11 | self.T = T 12 | 13 | def forward(self, y_s, y_t): 14 | p_s = F.log_softmax(y_s/self.T, dim=1) 15 | p_t = F.softmax(y_t/self.T, dim=1) 16 | loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0] 17 | return loss 18 | -------------------------------------------------------------------------------- /train_scripts.sh: -------------------------------------------------------------------------------- 1 | python train_student.py --path_t './save/models/wrn_40_4_vanilla/ckpt_epoch_240.pth' --model_s wrn_16_2 -NT 4 -a 1 -b 55 -c 0.1 2 | python train_student.py --path_t './save/models/wrn_40_2_vanilla/ckpt_epoch_240.pth' --model_s wrn_16_2 -NT 4 -a 1.0 -b 65 -c 1.0 3 | python train_student.py --path_t './save/models/resnet56_vanilla/ckpt_epoch_240.pth' --model_s resnet20 -NT 4 -a 1 -b 10 -c 1 4 | python train_student.py --path_t './save/models/ResNet50_vanilla/ckpt_epoch_240.pth' --model_s MobileNetV2 -NT 4 -a 1.0 -b 33.0 -c 0.1 5 | python train_student.py --path_t './save/models/ResNet50_vanilla/ckpt_epoch_240.pth' --model_s vgg8 -NT 4 -a 1.0 -b 8 -c 2.0 6 | -------------------------------------------------------------------------------- /evaluate_scripts.sh: -------------------------------------------------------------------------------- 1 | #evalaute pretrained student models 2 | python evaluate_student.py --model_path './save/pretrained_student_model/T_wrn40_4_S_wrn_16_2_cifar100_CID/ckpt_epoch_240.pth' --model_s wrn_16_2 3 | python evaluate_student.py --model_path './save/pretrained_student_model/T_wrn40_2_S_wrn_16_2_cifar100_CID/ckpt_epoch_240.pth' --model_s wrn_16_2 4 | python evaluate_student.py --model_path './save/pretrained_student_model/T_resnet56_S_resnet20_cifar100_CID/ckpt_epoch_240.pth' --model_s resnet20 5 | python evaluate_student.py --model_path './save/pretrained_student_model/T_ResNet50_S_MobileNetV2_cifar100_CID/ckpt_epoch_240.pth' --model_s MobileNetV2 6 | python evaluate_student.py --model_path './save/pretrained_student_model/T_ResNet50_S_vgg8_cifar100_CID/ckpt_epoch_240.pth' --model_s vgg8 -------------------------------------------------------------------------------- /distiller_zoo/mse_norm.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch 5 | 6 | 7 | class NORM_MSE(nn.Module): 8 | 9 | def __init__(self): 10 | super(NORM_MSE, self).__init__() 11 | self.MS = nn.MSELoss(reduction='none') #nn.MSELoss(size_average=False) 12 | 13 | def forward(self, output, target): 14 | 15 | target = target.view(target.shape[0], -1) 16 | 17 | output = output.view(output.shape[0], -1) 18 | 19 | magnitute = torch.norm(target,dim=1) 20 | 21 | magnitute_square = magnitute**2 22 | 23 | magnitute_square = torch.reshape(magnitute_square,(output.shape[0], -1) ) 24 | 25 | 26 | loss = torch.sum( self.MS(output, target)/magnitute_square )/target.shape[0] 27 | 28 | return loss 29 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | ######################################### 7 | # ===== Classifiers ===== # 8 | ######################################### 9 | 10 | class LinearClassifier(nn.Module): 11 | 12 | def __init__(self, dim_in, n_label=10): 13 | super(LinearClassifier, self).__init__() 14 | 15 | self.net = nn.Linear(dim_in, n_label) 16 | 17 | def forward(self, x): 18 | return self.net(x) 19 | 20 | 21 | class NonLinearClassifier(nn.Module): 22 | 23 | def __init__(self, dim_in, n_label=10, p=0.1): 24 | super(NonLinearClassifier, self).__init__() 25 | 26 | self.net = nn.Sequential( 27 | nn.Linear(dim_in, 200), 28 | nn.Dropout(p=p), 29 | nn.BatchNorm1d(200), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(200, n_label), 32 | ) 33 | 34 | def forward(self, x): 35 | return self.net(x) 36 | -------------------------------------------------------------------------------- /distiller_zoo/COS.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Cosine(nn.Module): 9 | 10 | def __init__(self): 11 | super(Cosine, self).__init__() 12 | 13 | def forward(self, g_s, g_t): 14 | return self.similarity_loss(g_s, g_t) 15 | 16 | def similarity_loss(self, f_s, f_t): 17 | 18 | bsz = f_s.shape[0] #64* 19 | f_s = f_s.view(bsz, -1)#64*dim 20 | f_s = torch.nn.functional.normalize(f_s)#64*dim 21 | 22 | f_t = f_t.view(bsz, -1) 23 | f_t = torch.nn.functional.normalize(f_t)#64*dim 24 | 25 | G_s = torch.mm(f_s, torch.t(f_s))#64*dim 26 | # G_s = G_s / G_s.norm(2) 27 | 28 | G_t = torch.mm(f_t, torch.t(f_t)) 29 | # G_t = G_t / G_t.norm(2) 30 | 31 | G_diff = G_t - G_s 32 | 33 | #print('G_diff0', G_diff) 34 | 35 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 36 | 37 | return loss 38 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4 2 | from .resnetv2 import ResNet50, ResNet34, ResNet10 3 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2, wrn_40_4, wrn_16_10, wrn_10_10, wrn_16_4 4 | from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn 5 | from .mobilenetv2 import mobile_half 6 | from .ShuffleNetv1 import ShuffleV1 7 | from .ShuffleNetv2 import ShuffleV2 8 | 9 | model_dict = { 10 | 'resnet8': resnet8, 11 | 'resnet14': resnet14, 12 | 'resnet20': resnet20, 13 | 'resnet32': resnet32, 14 | 'resnet44': resnet44, 15 | 'resnet56': resnet56, 16 | 'resnet110': resnet110, 17 | 'resnet8x4': resnet8x4, 18 | 'resnet32x4': resnet32x4, 19 | 'ResNet34': ResNet34, 20 | 'ResNet50': ResNet50, 21 | 'ResNet10': ResNet10, 22 | 'wrn_16_1': wrn_16_1, 23 | 'wrn_16_2': wrn_16_2, 24 | 'wrn_10_10': wrn_10_10, 25 | 'wrn_40_1': wrn_40_1, 26 | 'wrn_40_2': wrn_40_2, 27 | 'wrn_40_4': wrn_40_4, 28 | 'wrn_16_10': wrn_16_10, 29 | 'wrn_16_4': wrn_16_4, 30 | 'vgg8': vgg8_bn, 31 | 'vgg11': vgg11_bn, 32 | 'vgg13': vgg13_bn, 33 | 'vgg16': vgg16_bn, 34 | 'vgg19': vgg19_bn, 35 | 'MobileNetV2': mobile_half, 36 | 'ShuffleV1': ShuffleV1, 37 | 'ShuffleV2': ShuffleV2, 38 | } 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Comprehensive Knowledge Distillation with Causal Intervention 2 | 3 | This repository is a PyTorch implementation of "Comprehensive Knowledge Distillation with Causal Intervention". The code is modified from [CRD], and the pretrained teachers (except WRN-40-4) are also downloaded from [CRD]. 4 | 5 | ## Requirements 6 | 7 | The code was tested on 8 | ``` 9 | Python 3.6 10 | torch 1.2.0 11 | torchvision 0.4.0 12 | ``` 13 | 14 | ## Evaluation 15 | To evaluate our pre-trained light-weight student networks, first download the folder "pretrained_student_model" from [CID models] into the "save" folder, then simply run the command below to evaluate these light-weight students: 16 | 17 | `sh evaluate_scripts.sh` 18 | 19 | ## Training 20 | To train students from scratch by distilling knowledge from teacher networks with CID, first download the pretrained teacher folder "models" from [CID models] into the "save" folder, and then simply run the command below to compress large models to smaller ones: 21 | 22 | `sh train_scripts.sh` 23 | 24 | [CID models]: https://drive.google.com/drive/folders/1s-NwnDw3VXc_r87-XHEg1iM0KhxpXlbj?usp=sharing 25 | 26 | [CRD]: https://github.com/HobbitLong/RepDistiller 27 | 28 | ## Citation 29 | If you find this code helpful, you may consider citing this paper: 30 | ```bibtex 31 | @inproceedings{deng2021comprehensive, 32 | title={Comprehensive Knowledge Distillation with Causal Intervention}, 33 | author={Deng, Xiang and Zhang, Zhongfei}, 34 | booktitle = {Proceedings of the 30th Annual Conference on Neural Information Processing Systems}, 35 | year={2021} 36 | } 37 | ``` 38 | -------------------------------------------------------------------------------- /helper/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import numpy as np 5 | 6 | def normalize( x): 7 | x_norm = torch.norm(x, p=2, dim=1).unsqueeze(1).expand_as(x) 8 | x_normalized = x.div(x_norm + 0.00001) 9 | return x_normalized 10 | 11 | 12 | 13 | def adjust_learning_rate(epoch, opt, optimizer): 14 | """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 15 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 16 | if steps > 0: 17 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 18 | for param_group in optimizer.param_groups: 19 | param_group['lr'] = new_lr 20 | 21 | 22 | class AverageMeter(object): 23 | """Computes and stores the average and current value""" 24 | def __init__(self): 25 | self.reset() 26 | 27 | def reset(self): 28 | self.val = 0 29 | self.avg = 0 30 | self.sum = 0 31 | self.count = 0 32 | 33 | def update(self, val, n=1): 34 | self.val = val 35 | self.sum += val * n 36 | self.count += n 37 | self.avg = self.sum / self.count 38 | 39 | 40 | def accuracy(output, target, topk=(1,)): 41 | """Computes the accuracy over the k top predictions for the specified values of k""" 42 | with torch.no_grad(): 43 | maxk = max(topk) 44 | batch_size = target.size(0) 45 | 46 | _, pred = output.topk(maxk, 1, True, True) 47 | pred = pred.t() 48 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 49 | 50 | res = [] 51 | for k in topk: 52 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 53 | res.append(correct_k.mul_(100.0 / batch_size)) 54 | return res 55 | 56 | 57 | 58 | 59 | def set_seed(seed): 60 | np.random.seed(seed) 61 | torch.manual_seed(seed) 62 | torch.cuda.manual_seed(seed) 63 | 64 | 65 | def cluster(f_s, f_t, label, num_classes): 66 | 67 | length= f_s.shape[0] 68 | 69 | list_s = [] 70 | list_t = [] 71 | 72 | for i in range(num_classes): 73 | list_s.append([]) 74 | list_t.append([]) 75 | 76 | for i in range(length): 77 | list_s[label[i]].append(f_s[i]) 78 | list_t[label[i]].append(f_t[i]) 79 | 80 | 81 | return list_s, list_t 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | 89 | pass 90 | -------------------------------------------------------------------------------- /dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | from PIL import Image 7 | 8 | 9 | 10 | def get_data_folder(): 11 | 12 | data_folder = './data' 13 | 14 | if not os.path.isdir(data_folder): 15 | os.makedirs(data_folder) 16 | 17 | return data_folder 18 | 19 | 20 | class CIFAR100Instance(datasets.CIFAR100): 21 | """CIFAR100Instance Dataset. 22 | """ 23 | def __getitem__(self, index): 24 | if self.train: 25 | img, target = self.data[index], self.targets[index] 26 | else: 27 | img, target = self.test_data[index], self.test_labels[index] 28 | 29 | # doing this so that it is consistent with all other datasets 30 | # to return a PIL Image 31 | img = Image.fromarray(img) 32 | 33 | if self.transform is not None: 34 | img = self.transform(img) 35 | 36 | if self.target_transform is not None: 37 | target = self.target_transform(target) 38 | 39 | return img, target, index 40 | 41 | 42 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False): 43 | """ 44 | cifar 100 45 | """ 46 | data_folder = get_data_folder() 47 | 48 | train_transform = transforms.Compose([ 49 | transforms.RandomCrop(32, padding=4), 50 | transforms.RandomHorizontalFlip(), 51 | transforms.ToTensor(), 52 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 53 | ]) 54 | test_transform = transforms.Compose([ 55 | transforms.ToTensor(), 56 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 57 | ]) 58 | 59 | if is_instance: 60 | train_set = CIFAR100Instance(root=data_folder, 61 | download=True, 62 | train=True, 63 | transform=train_transform) 64 | n_data = len(train_set) 65 | else: 66 | train_set = datasets.CIFAR100(root=data_folder, 67 | download=True, 68 | train=True, 69 | transform=train_transform) 70 | train_loader = DataLoader(train_set, 71 | batch_size=batch_size, 72 | shuffle=True, 73 | num_workers=num_workers) 74 | 75 | test_set = datasets.CIFAR100(root=data_folder, 76 | download=True, 77 | train=False, 78 | transform=test_transform) 79 | test_loader = DataLoader(test_set, 80 | batch_size=int(batch_size/2), 81 | shuffle=False, 82 | num_workers=int(num_workers/2)) 83 | 84 | if is_instance: 85 | return train_loader, test_loader, n_data 86 | else: 87 | return train_loader, test_loader -------------------------------------------------------------------------------- /evaluate_student.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import argparse 8 | import torch 9 | import torch.nn as nn 10 | import torch.backends.cudnn as cudnn 11 | from models import model_dict 12 | from models.util import Reg 13 | from dataset.cifar100 import get_cifar100_dataloaders 14 | from helper.loops import validate_st 15 | import os 16 | 17 | 18 | def parse_option(): 19 | 20 | parser = argparse.ArgumentParser('Arguments for training') 21 | 22 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 23 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 24 | parser.add_argument('-NT', '--net_T', type=float, default=4, help='net Tempereture') 25 | 26 | # dataset 27 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset') 28 | 29 | 30 | # model 31 | parser.add_argument('--model_s', type=str, default='resnet8', 32 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 33 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_10_10','wrn_40_1', 'wrn_40_2', 34 | 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'ResNet50', 'MobileNetV2', 'ShuffleV1', 35 | 'ShuffleV2', 'ResNet34', 'wrn_16_4', 'wrn_40_4', 'wrn_16_10', 'ResNet10']) 36 | 37 | parser.add_argument('--model_path', type=str, default=None, help='student model') 38 | 39 | parser.add_argument('--hint_layer', default=-1, type=int, choices=[-1]) 40 | 41 | opt = parser.parse_args() 42 | 43 | return opt 44 | 45 | 46 | def main(): 47 | 48 | opt = parse_option() 49 | 50 | print(opt) 51 | 52 | # dataloader 53 | if opt.dataset == 'cifar100': 54 | 55 | _, val_loader, n_data = get_cifar100_dataloaders(batch_size=opt.batch_size, 56 | num_workers=opt.num_workers, 57 | is_instance=True) 58 | n_cls = 100 59 | else: 60 | raise NotImplementedError(opt.dataset) 61 | 62 | # model 63 | model_s = model_dict[opt.model_s](num_classes=n_cls) 64 | model_s.load_state_dict(torch.load(opt.model_path)['model']) 65 | 66 | data = torch.randn(2, 3, 32, 32) 67 | 68 | model_s.eval() 69 | 70 | feat_s, _= model_s(data, is_feat=True) 71 | 72 | _, Cs_h = feat_s[opt.hint_layer].shape 73 | 74 | model_s_fc_new = Reg( Cs_h*2, n_cls) 75 | model_s_fc_new.load_state_dict(torch.load(opt.model_path)['model_s_fc_new']) 76 | 77 | context = torch.load(opt.model_path)['context_old'] 78 | 79 | criterion_cls = nn.CrossEntropyLoss() 80 | 81 | if torch.cuda.is_available(): 82 | model_s_fc_new.cuda() 83 | model_s.cuda() 84 | context = context.cuda() 85 | criterion_cls.cuda() 86 | cudnn.benchmark = True 87 | 88 | 89 | test_acc, tect_acc_top5 = validate_st(val_loader, model_s, criterion_cls, opt, context, model_s_fc_new) 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /models/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N,C,H,W = x.size() 17 | g = self.groups 18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 23 | super(Bottleneck, self).__init__() 24 | self.is_last = is_last 25 | self.stride = stride 26 | 27 | mid_planes = int(out_planes/4) 28 | g = 1 if in_planes == 24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res 48 | out = F.relu(preact) 49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 50 | if self.is_last: 51 | return out, preact 52 | else: 53 | return out 54 | 55 | 56 | class ShuffleNet(nn.Module): 57 | def __init__(self, cfg, num_classes=10): 58 | super(ShuffleNet, self).__init__() 59 | out_planes = cfg['out_planes'] 60 | num_blocks = cfg['num_blocks'] 61 | groups = cfg['groups'] 62 | 63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(24) 65 | self.in_planes = 24 66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 69 | self.fc = nn.Linear(out_planes[2], num_classes) 70 | 71 | def _make_layer(self, out_planes, num_blocks, groups): 72 | layers = [] 73 | for i in range(num_blocks): 74 | stride = 2 if i == 0 else 1 75 | cat_planes = self.in_planes if i == 0 else 0 76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 77 | stride=stride, 78 | groups=groups, 79 | is_last=(i == num_blocks - 1))) 80 | self.in_planes = out_planes 81 | return nn.Sequential(*layers) 82 | 83 | def get_feat_modules(self): 84 | feat_m = nn.ModuleList([]) 85 | feat_m.append(self.conv1) 86 | feat_m.append(self.bn1) 87 | feat_m.append(self.layer1) 88 | feat_m.append(self.layer2) 89 | feat_m.append(self.layer3) 90 | return feat_m 91 | 92 | def get_bn_before_relu(self): 93 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 94 | 95 | def forward(self, x, is_feat=False, preact=False): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | f0 = out 98 | out, f1_pre = self.layer1(out) 99 | f1 = out 100 | out, f2_pre = self.layer2(out) 101 | f2 = out 102 | out, f3_pre = self.layer3(out) 103 | f3 = out 104 | out = F.avg_pool2d(out, 4) 105 | out = out.view(out.size(0), -1) 106 | f4 = out 107 | out = self.fc(out) 108 | 109 | if is_feat: 110 | if preact: 111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 112 | else: 113 | return [f0, f1, f2, f3, f4], out 114 | else: 115 | return out 116 | 117 | 118 | def ShuffleV1(**kwargs): 119 | cfg = { 120 | 'out_planes': [240, 480, 960], 121 | 'num_blocks': [4, 8, 4], 122 | 'groups': 3 123 | } 124 | return ShuffleNet(cfg, **kwargs) 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | x = torch.randn(2, 3, 32, 32) 130 | net = ShuffleV1(num_classes=100) 131 | import time 132 | a = time.time() 133 | feats, logit = net(x, is_feat=True, preact=True) 134 | b = time.time() 135 | print(b - a) 136 | for f in feats: 137 | print(f.shape, f.min().item()) 138 | print(logit.shape) 139 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNetV2 implementation used in 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['mobilenetv2_T_w', 'mobile_half'] 11 | 12 | BN = None 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | 23 | def conv_1x1_bn(inp, oup): 24 | return nn.Sequential( 25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 26 | nn.BatchNorm2d(oup), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio): 33 | super(InvertedResidual, self).__init__() 34 | self.blockname = None 35 | 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | nn.ReLU(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] 55 | 56 | def forward(self, x): 57 | t = x 58 | if self.use_res_connect: 59 | return t + self.conv(x) 60 | else: 61 | return self.conv(x) 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | """mobilenetV2""" 66 | def __init__(self, T, 67 | feature_dim, 68 | input_size=32, 69 | width_mult=1., 70 | remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | # building classifier 108 | # self.fc = nn.Sequential( 109 | # nn.Dropout(0.5), 110 | # nn.Linear(self.last_channel, feature_dim), 111 | # ) 112 | 113 | self.fc = nn.Linear(self.last_channel, feature_dim) 114 | 115 | 116 | H = input_size // (32//2) 117 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 118 | 119 | self._initialize_weights() 120 | print(T, width_mult) 121 | 122 | def get_bn_before_relu(self): 123 | bn1 = self.blocks[1][-1].conv[-1] 124 | bn2 = self.blocks[2][-1].conv[-1] 125 | bn3 = self.blocks[4][-1].conv[-1] 126 | bn4 = self.blocks[6][-1].conv[-1] 127 | return [bn1, bn2, bn3, bn4] 128 | 129 | def get_feat_modules(self): 130 | feat_m = nn.ModuleList([]) 131 | feat_m.append(self.conv1) 132 | feat_m.append(self.blocks) 133 | return feat_m 134 | 135 | def forward(self, x, is_feat=False, preact=False): 136 | 137 | out = self.conv1(x) 138 | f0 = out 139 | 140 | out = self.blocks[0](out) 141 | out = self.blocks[1](out) 142 | f1 = out 143 | out = self.blocks[2](out) 144 | f2 = out 145 | out = self.blocks[3](out) 146 | out = self.blocks[4](out) 147 | f3 = out 148 | out = self.blocks[5](out) 149 | out = self.blocks[6](out) 150 | f4 = out 151 | 152 | out = self.conv2(out) 153 | 154 | if not self.remove_avg: 155 | out = self.avgpool(out) 156 | out = out.view(out.size(0), -1) 157 | f5 = out 158 | out = self.fc(out) 159 | 160 | if is_feat: 161 | return [f0, f1, f2, f3, f4, f5], out 162 | else: 163 | return out 164 | 165 | def _initialize_weights(self): 166 | for m in self.modules(): 167 | if isinstance(m, nn.Conv2d): 168 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 169 | m.weight.data.normal_(0, math.sqrt(2. / n)) 170 | if m.bias is not None: 171 | m.bias.data.zero_() 172 | elif isinstance(m, nn.BatchNorm2d): 173 | m.weight.data.fill_(1) 174 | m.bias.data.zero_() 175 | elif isinstance(m, nn.Linear): 176 | n = m.weight.size(1) 177 | m.weight.data.normal_(0, 0.01) 178 | m.bias.data.zero_() 179 | 180 | 181 | def mobilenetv2_T_w(T, W, feature_dim=100): 182 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 183 | return model 184 | 185 | 186 | def mobile_half(num_classes): 187 | return mobilenetv2_T_w(6, 0.5, num_classes) 188 | 189 | 190 | if __name__ == '__main__': 191 | x = torch.randn(2, 3, 32, 32) 192 | 193 | net = mobile_half(100) 194 | 195 | feats, logit = net(x, is_feat=True, preact=True) 196 | for f in feats: 197 | print(f.shape, f.min().item()) 198 | print(logit.shape) 199 | 200 | for m in net.get_bn_before_relu(): 201 | if isinstance(m, nn.BatchNorm2d): 202 | print('pass') 203 | else: 204 | print('warning') 205 | 206 | -------------------------------------------------------------------------------- /models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Original Author: Wei Yang 8 | """ 9 | 10 | __all__ = ['wrn'] 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 15 | super(BasicBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.relu1 = nn.ReLU(inplace=True) 18 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | self.relu2 = nn.ReLU(inplace=True) 22 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 23 | padding=1, bias=False) 24 | self.droprate = dropRate 25 | self.equalInOut = (in_planes == out_planes) 26 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 27 | padding=0, bias=False) or None 28 | 29 | def forward(self, x): 30 | if not self.equalInOut: 31 | x = self.relu1(self.bn1(x)) 32 | else: 33 | out = self.relu1(self.bn1(x)) 34 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, training=self.training) 37 | out = self.conv2(out) 38 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 60 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. / n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def get_feat_modules(self): 89 | feat_m = nn.ModuleList([]) 90 | feat_m.append(self.conv1) 91 | feat_m.append(self.block1) 92 | feat_m.append(self.block2) 93 | feat_m.append(self.block3) 94 | return feat_m 95 | 96 | def get_bn_before_relu(self): 97 | bn1 = self.block2.layer[0].bn1 98 | bn2 = self.block3.layer[0].bn1 99 | bn3 = self.bn1 100 | 101 | return [bn1, bn2, bn3] 102 | 103 | def forward(self, x, is_feat=False, preact=False): 104 | out = self.conv1(x) 105 | f0 = out 106 | out = self.block1(out) 107 | f1 = out 108 | out = self.block2(out) 109 | f2 = out 110 | out = self.block3(out) 111 | f3 = out 112 | out = self.relu(self.bn1(out)) 113 | f3_5 = out 114 | out = F.avg_pool2d(out, 8) 115 | #f4 = out 116 | out = out.view(-1, self.nChannels) 117 | f4 = out 118 | out = self.fc(out) 119 | if is_feat: 120 | if preact: 121 | f1 = self.block2.layer[0].bn1(f1) 122 | f2 = self.block3.layer[0].bn1(f2) 123 | f3 = self.bn1(f3) 124 | return [f0, f1, f2, f3, f3_5, f4], out 125 | else: 126 | return out 127 | 128 | 129 | def wrn(**kwargs): 130 | """ 131 | Constructs a Wide Residual Networks. 132 | """ 133 | model = WideResNet(**kwargs) 134 | return model 135 | 136 | 137 | def wrn_40_4(**kwargs): 138 | model = WideResNet(depth=40, widen_factor=4, **kwargs) 139 | return model 140 | 141 | def wrn_40_2(**kwargs): 142 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 143 | return model 144 | 145 | 146 | def wrn_40_1(**kwargs): 147 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 148 | return model 149 | 150 | 151 | def wrn_16_10(**kwargs): 152 | model = WideResNet(depth=16, widen_factor=10, **kwargs) 153 | return model 154 | 155 | def wrn_10_10(**kwargs): 156 | model = WideResNet(depth=10, widen_factor=10, **kwargs) 157 | return model 158 | 159 | def wrn_16_4(**kwargs): 160 | model = WideResNet(depth=16, widen_factor=4, **kwargs) 161 | return model 162 | 163 | def wrn_16_2(**kwargs): 164 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 165 | return model 166 | 167 | 168 | def wrn_16_1(**kwargs): 169 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 170 | return model 171 | 172 | 173 | if __name__ == '__main__': 174 | import torch 175 | 176 | x = torch.randn(2, 3, 32, 32) 177 | net = wrn_40_2(num_classes=100) 178 | feats, logit = net(x, is_feat=True, preact=True) 179 | 180 | for f in feats: 181 | print(f.shape, f.min().item()) 182 | print(logit.shape) 183 | 184 | for m in net.get_bn_before_relu(): 185 | if isinstance(m, nn.BatchNorm2d): 186 | print('pass') 187 | else: 188 | print('warning') 189 | -------------------------------------------------------------------------------- /models/resnetv2.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 | def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 87 | self.linear = nn.Linear(512 * block.expansion, num_classes) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | elif isinstance(m, BasicBlock): 104 | nn.init.constant_(m.bn2.weight, 0) 105 | 106 | def get_feat_modules(self): 107 | feat_m = nn.ModuleList([]) 108 | feat_m.append(self.conv1) 109 | feat_m.append(self.bn1) 110 | feat_m.append(self.layer1) 111 | feat_m.append(self.layer2) 112 | feat_m.append(self.layer3) 113 | feat_m.append(self.layer4) 114 | return feat_m 115 | 116 | def get_bn_before_relu(self): 117 | if isinstance(self.layer1[0], Bottleneck): 118 | bn1 = self.layer1[-1].bn3 119 | bn2 = self.layer2[-1].bn3 120 | bn3 = self.layer3[-1].bn3 121 | bn4 = self.layer4[-1].bn3 122 | elif isinstance(self.layer1[0], BasicBlock): 123 | bn1 = self.layer1[-1].bn2 124 | bn2 = self.layer2[-1].bn2 125 | bn3 = self.layer3[-1].bn2 126 | bn4 = self.layer4[-1].bn2 127 | else: 128 | raise NotImplementedError('ResNet unknown block error !!!') 129 | 130 | return [bn1, bn2, bn3, bn4] 131 | 132 | def _make_layer(self, block, planes, num_blocks, stride): 133 | strides = [stride] + [1] * (num_blocks - 1) 134 | layers = [] 135 | for i in range(num_blocks): 136 | stride = strides[i] 137 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 138 | self.in_planes = planes * block.expansion 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, is_feat=False, preact=False): 142 | out = F.relu(self.bn1(self.conv1(x))) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out, f4_pre = self.layer4(out) 151 | f4 = out 152 | out = self.avgpool(out) 153 | out = out.view(out.size(0), -1) 154 | f5 = out 155 | out = self.linear(out) 156 | if is_feat: 157 | if preact: 158 | return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out] 159 | else: 160 | return [f0, f1, f2, f3, f4, f5], out 161 | else: 162 | return out 163 | 164 | 165 | def ResNet10(**kwargs): 166 | return ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 167 | 168 | def ResNet18(**kwargs): 169 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 170 | 171 | 172 | def ResNet34(**kwargs): 173 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 174 | 175 | 176 | def ResNet50(**kwargs): 177 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 178 | 179 | 180 | def ResNet101(**kwargs): 181 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 182 | 183 | 184 | def ResNet152(**kwargs): 185 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 186 | 187 | 188 | if __name__ == '__main__': 189 | net = ResNet18(num_classes=100) 190 | x = torch.randn(2, 3, 32, 32) 191 | feats, logit = net(x, is_feat=True, preact=True) 192 | 193 | for f in feats: 194 | print(f.shape, f.min().item()) 195 | print(logit.shape) 196 | 197 | for m in net.get_bn_before_relu(): 198 | if isinstance(m, nn.BatchNorm2d): 199 | print('pass') 200 | else: 201 | print('warning') 202 | -------------------------------------------------------------------------------- /models/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups=2): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N, C, H, W = x.size() 17 | g = self.groups 18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 19 | 20 | 21 | class SplitBlock(nn.Module): 22 | def __init__(self, ratio): 23 | super(SplitBlock, self).__init__() 24 | self.ratio = ratio 25 | 26 | def forward(self, x): 27 | c = int(x.size(1) * self.ratio) 28 | return x[:, :c, :, :], x[:, c:, :, :] 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 33 | super(BasicBlock, self).__init__() 34 | self.is_last = is_last 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | preact = self.bn3(self.conv3(out)) 53 | out = F.relu(preact) 54 | # out = F.relu(self.bn3(self.conv3(out))) 55 | preact = torch.cat([x1, preact], 1) 56 | out = torch.cat([x1, out], 1) 57 | out = self.shuffle(out) 58 | if self.is_last: 59 | return out, preact 60 | else: 61 | return out 62 | 63 | 64 | class DownBlock(nn.Module): 65 | def __init__(self, in_channels, out_channels): 66 | super(DownBlock, self).__init__() 67 | mid_channels = out_channels // 2 68 | # left 69 | self.conv1 = nn.Conv2d(in_channels, in_channels, 70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 71 | self.bn1 = nn.BatchNorm2d(in_channels) 72 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 73 | kernel_size=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(mid_channels) 75 | # right 76 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(mid_channels) 79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 81 | self.bn4 = nn.BatchNorm2d(mid_channels) 82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 83 | kernel_size=1, bias=False) 84 | self.bn5 = nn.BatchNorm2d(mid_channels) 85 | 86 | self.shuffle = ShuffleBlock() 87 | 88 | def forward(self, x): 89 | # left 90 | out1 = self.bn1(self.conv1(x)) 91 | out1 = F.relu(self.bn2(self.conv2(out1))) 92 | # right 93 | out2 = F.relu(self.bn3(self.conv3(x))) 94 | out2 = self.bn4(self.conv4(out2)) 95 | out2 = F.relu(self.bn5(self.conv5(out2))) 96 | # concat 97 | out = torch.cat([out1, out2], 1) 98 | out = self.shuffle(out) 99 | return out 100 | 101 | 102 | class ShuffleNetV2(nn.Module): 103 | def __init__(self, net_size, num_classes=10): 104 | super(ShuffleNetV2, self).__init__() 105 | out_channels = configs[net_size]['out_channels'] 106 | num_blocks = configs[net_size]['num_blocks'] 107 | 108 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 109 | # stride=1, padding=1, bias=False) 110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(24) 112 | self.in_channels = 24 113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 117 | kernel_size=1, stride=1, padding=0, bias=False) 118 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 119 | self.fc = nn.Linear(out_channels[3], num_classes) 120 | 121 | def _make_layer(self, out_channels, num_blocks): 122 | layers = [DownBlock(self.in_channels, out_channels)] 123 | for i in range(num_blocks): 124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 125 | self.in_channels = out_channels 126 | return nn.Sequential(*layers) 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.bn1) 132 | feat_m.append(self.layer1) 133 | feat_m.append(self.layer2) 134 | feat_m.append(self.layer3) 135 | return feat_m 136 | 137 | def get_bn_before_relu(self): 138 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 139 | 140 | def forward(self, x, is_feat=False, preact=False): 141 | out = F.relu(self.bn1(self.conv1(x))) 142 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out = F.relu(self.bn2(self.conv2(out))) 151 | out = F.avg_pool2d(out, 4) 152 | out = out.view(out.size(0), -1) 153 | f4 = out 154 | out = self.fc(out) 155 | if is_feat: 156 | if preact: 157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 158 | else: 159 | return [f0, f1, f2, f3, f4], out 160 | else: 161 | return out 162 | 163 | 164 | configs = { 165 | 0.2: { 166 | 'out_channels': (40, 80, 160, 512), 167 | 'num_blocks': (3, 3, 3) 168 | }, 169 | 170 | 0.3: { 171 | 'out_channels': (40, 80, 160, 512), 172 | 'num_blocks': (3, 7, 3) 173 | }, 174 | 175 | 0.5: { 176 | 'out_channels': (48, 96, 192, 1024), 177 | 'num_blocks': (3, 7, 3) 178 | }, 179 | 180 | 1: { 181 | 'out_channels': (116, 232, 464, 1024), 182 | 'num_blocks': (3, 7, 3) 183 | }, 184 | 1.5: { 185 | 'out_channels': (176, 352, 704, 1024), 186 | 'num_blocks': (3, 7, 3) 187 | }, 188 | 2: { 189 | 'out_channels': (224, 488, 976, 2048), 190 | 'num_blocks': (3, 7, 3) 191 | } 192 | } 193 | 194 | 195 | def ShuffleV2(**kwargs): 196 | model = ShuffleNetV2(net_size=1, **kwargs) 197 | return model 198 | 199 | 200 | if __name__ == '__main__': 201 | net = ShuffleV2(num_classes=100) 202 | x = torch.randn(3, 3, 32, 32) 203 | import time 204 | a = time.time() 205 | feats, logit = net(x, is_feat=True, preact=True) 206 | b = time.time() 207 | print(b - a) 208 | for f in feats: 209 | print(f.shape, f.min().item()) 210 | print(logit.shape) 211 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 28 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 29 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 30 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 31 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 32 | 33 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 37 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 38 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 39 | 40 | self.classifier = nn.Linear(512, num_classes) 41 | self._initialize_weights() 42 | 43 | def get_feat_modules(self): 44 | feat_m = nn.ModuleList([]) 45 | feat_m.append(self.block0) 46 | feat_m.append(self.pool0) 47 | feat_m.append(self.block1) 48 | feat_m.append(self.pool1) 49 | feat_m.append(self.block2) 50 | feat_m.append(self.pool2) 51 | feat_m.append(self.block3) 52 | feat_m.append(self.pool3) 53 | feat_m.append(self.block4) 54 | feat_m.append(self.pool4) 55 | return feat_m 56 | 57 | def get_bn_before_relu(self): 58 | bn1 = self.block1[-1] 59 | bn2 = self.block2[-1] 60 | bn3 = self.block3[-1] 61 | bn4 = self.block4[-1] 62 | return [bn1, bn2, bn3, bn4] 63 | 64 | def forward(self, x, is_feat=False, preact=False): 65 | h = x.shape[2] 66 | x = F.relu(self.block0(x)) 67 | f0 = x 68 | x = self.pool0(x) 69 | x = self.block1(x) 70 | f1_pre = x 71 | x = F.relu(x) 72 | f1 = x 73 | x = self.pool1(x) 74 | x = self.block2(x) 75 | f2_pre = x 76 | x = F.relu(x) 77 | f2 = x 78 | x = self.pool2(x) 79 | x = self.block3(x) 80 | f3_pre = x 81 | x = F.relu(x) 82 | f3 = x 83 | if h == 64: 84 | x = self.pool3(x) 85 | x = self.block4(x) 86 | f4_pre = x 87 | x = F.relu(x) 88 | f4 = x 89 | x = self.pool4(x) 90 | x = x.view(x.size(0), -1) 91 | f5 = x 92 | x = self.classifier(x) 93 | 94 | if is_feat: 95 | if preact: 96 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x 97 | else: 98 | return [f0, f1, f2, f3, f4, f5], x 99 | else: 100 | return x 101 | 102 | @staticmethod 103 | def _make_layers(cfg, batch_norm=False, in_channels=3): 104 | layers = [] 105 | for v in cfg: 106 | if v == 'M': 107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 108 | else: 109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 110 | if batch_norm: 111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 112 | else: 113 | layers += [conv2d, nn.ReLU(inplace=True)] 114 | in_channels = v 115 | layers = layers[:-1] 116 | return nn.Sequential(*layers) 117 | 118 | def _initialize_weights(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | m.weight.data.normal_(0, math.sqrt(2. / n)) 123 | if m.bias is not None: 124 | m.bias.data.zero_() 125 | elif isinstance(m, nn.BatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | elif isinstance(m, nn.Linear): 129 | n = m.weight.size(1) 130 | m.weight.data.normal_(0, 0.01) 131 | m.bias.data.zero_() 132 | 133 | 134 | cfg = { 135 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 136 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 137 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 138 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 139 | 'S': [[64], [128], [256], [512], [512]], 140 | } 141 | 142 | 143 | def vgg8(**kwargs): 144 | """VGG 8-layer model (configuration "S") 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = VGG(cfg['S'], **kwargs) 149 | return model 150 | 151 | 152 | def vgg8_bn(**kwargs): 153 | """VGG 8-layer model (configuration "S") 154 | Args: 155 | pretrained (bool): If True, returns a model pre-trained on ImageNet 156 | """ 157 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 158 | return model 159 | 160 | 161 | def vgg11(**kwargs): 162 | """VGG 11-layer model (configuration "A") 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = VGG(cfg['A'], **kwargs) 167 | return model 168 | 169 | 170 | def vgg11_bn(**kwargs): 171 | """VGG 11-layer model (configuration "A") with batch normalization""" 172 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 173 | return model 174 | 175 | 176 | def vgg13(**kwargs): 177 | """VGG 13-layer model (configuration "B") 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | """ 181 | model = VGG(cfg['B'], **kwargs) 182 | return model 183 | 184 | 185 | def vgg13_bn(**kwargs): 186 | """VGG 13-layer model (configuration "B") with batch normalization""" 187 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 188 | return model 189 | 190 | 191 | def vgg16(**kwargs): 192 | """VGG 16-layer model (configuration "D") 193 | Args: 194 | pretrained (bool): If True, returns a model pre-trained on ImageNet 195 | """ 196 | model = VGG(cfg['D'], **kwargs) 197 | return model 198 | 199 | 200 | def vgg16_bn(**kwargs): 201 | """VGG 16-layer model (configuration "D") with batch normalization""" 202 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 203 | return model 204 | 205 | 206 | def vgg19(**kwargs): 207 | """VGG 19-layer model (configuration "E") 208 | Args: 209 | pretrained (bool): If True, returns a model pre-trained on ImageNet 210 | """ 211 | model = VGG(cfg['E'], **kwargs) 212 | return model 213 | 214 | 215 | def vgg19_bn(**kwargs): 216 | """VGG 19-layer model (configuration 'E') with batch normalization""" 217 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 218 | return model 219 | 220 | 221 | if __name__ == '__main__': 222 | import torch 223 | 224 | x = torch.randn(2, 3, 32, 32) 225 | net = vgg19_bn(num_classes=100) 226 | feats, logit = net(x, is_feat=True, preact=True) 227 | 228 | for f in feats: 229 | print(f.shape, f.min().item()) 230 | print(logit.shape) 231 | 232 | for m in net.get_bn_before_relu(): 233 | if isinstance(m, nn.BatchNorm2d): 234 | print('pass') 235 | else: 236 | print('warning') 237 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | 15 | __all__ = ['resnet'] 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 28 | super(BasicBlock, self).__init__() 29 | self.is_last = is_last 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | preact = out 53 | out = F.relu(out) 54 | if self.is_last: 55 | return out, preact 56 | else: 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 64 | super(Bottleneck, self).__init__() 65 | self.is_last = is_last 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | preact = out 96 | out = F.relu(out) 97 | if self.is_last: 98 | return out, preact 99 | else: 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 106 | super(ResNet, self).__init__() 107 | # Model type specifies number of layers for CIFAR-10 model 108 | if block_name.lower() == 'basicblock': 109 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 110 | n = (depth - 2) // 6 111 | block = BasicBlock 112 | elif block_name.lower() == 'bottleneck': 113 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 114 | n = (depth - 2) // 9 115 | block = Bottleneck 116 | else: 117 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 118 | 119 | self.inplanes = num_filters[0] 120 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, 121 | bias=False) 122 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.layer1 = self._make_layer(block, num_filters[1], n) 125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 127 | self.avgpool = nn.AvgPool2d(8) 128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = list([]) 147 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def get_feat_modules(self): 155 | feat_m = nn.ModuleList([]) 156 | feat_m.append(self.conv1) 157 | feat_m.append(self.bn1) 158 | feat_m.append(self.relu) 159 | feat_m.append(self.layer1) 160 | feat_m.append(self.layer2) 161 | feat_m.append(self.layer3) 162 | return feat_m 163 | 164 | def get_bn_before_relu(self): 165 | if isinstance(self.layer1[0], Bottleneck): 166 | bn1 = self.layer1[-1].bn3 167 | bn2 = self.layer2[-1].bn3 168 | bn3 = self.layer3[-1].bn3 169 | elif isinstance(self.layer1[0], BasicBlock): 170 | bn1 = self.layer1[-1].bn2 171 | bn2 = self.layer2[-1].bn2 172 | bn3 = self.layer3[-1].bn2 173 | else: 174 | raise NotImplementedError('ResNet unknown block error !!!') 175 | 176 | return [bn1, bn2, bn3] 177 | 178 | def forward(self, x, is_feat=False, preact=False): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) # 32x32 182 | f0 = x 183 | 184 | x, f1_pre = self.layer1(x) # 32x32 185 | f1 = x 186 | x, f2_pre = self.layer2(x) # 16x16 187 | f2 = x 188 | x, f3_pre = self.layer3(x) # 8x8 189 | f3 = x 190 | 191 | x = self.avgpool(x) 192 | x = x.view(x.size(0), -1) 193 | f4 = x 194 | x = self.fc(x) 195 | 196 | if is_feat: 197 | if preact: 198 | return [f0, f1_pre, f2_pre, f3_pre, f4], x 199 | else: 200 | return [f0, f1, f2, f3, f4], x 201 | else: 202 | return x 203 | 204 | 205 | def resnet8(**kwargs): 206 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) 207 | 208 | 209 | def resnet14(**kwargs): 210 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 211 | 212 | 213 | def resnet20(**kwargs): 214 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 215 | 216 | 217 | def resnet32(**kwargs): 218 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 219 | 220 | 221 | def resnet44(**kwargs): 222 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 223 | 224 | 225 | def resnet56(**kwargs): 226 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 227 | 228 | 229 | def resnet110(**kwargs): 230 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 231 | 232 | 233 | def resnet8x4(**kwargs): 234 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 235 | 236 | 237 | def resnet32x4(**kwargs): 238 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 239 | 240 | 241 | if __name__ == '__main__': 242 | import torch 243 | 244 | x = torch.randn(2, 3, 32, 32) 245 | net = resnet8x4(num_classes=20) 246 | feats, logit = net(x, is_feat=True, preact=True) 247 | 248 | for f in feats: 249 | print(f.shape, f.min().item()) 250 | print(logit.shape) 251 | 252 | for m in net.get_bn_before_relu(): 253 | if isinstance(m, nn.BatchNorm2d): 254 | print('pass') 255 | else: 256 | print('warning') 257 | -------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | """ 2 | training framework of CID 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import os 8 | import argparse 9 | import time 10 | 11 | import torch 12 | import torch.optim as optim 13 | import torch.nn as nn 14 | import torch.backends.cudnn as cudnn 15 | 16 | 17 | from models import model_dict 18 | from models.util import Reg 19 | 20 | from dataset.cifar100 import get_cifar100_dataloaders 21 | 22 | from helper.util import adjust_learning_rate 23 | 24 | from distiller_zoo import KL, Cosine 25 | from distiller_zoo import NORM_MSE 26 | 27 | from helper.loops import train_distill as train_init, validate, validate_st, train_distill_context as train 28 | from helper.util import set_seed 29 | 30 | 31 | def parse_option(): 32 | 33 | parser = argparse.ArgumentParser('Arguments for training') 34 | 35 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency') 36 | parser.add_argument('--tb_freq', type=int, default=500, help='tb frequency') 37 | parser.add_argument('--save_freq', type=int, default=240, help='save frequency') 38 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 39 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 40 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs') 41 | parser.add_argument('--init_epochs', type=int, default=20, help='init training for methods') 42 | 43 | # optimization 44 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 45 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list') 46 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 47 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 48 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 49 | 50 | # dataset 51 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset') 52 | 53 | # model 54 | parser.add_argument('--model_s', type=str, default='resnet8', 55 | choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 56 | 'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_10_10','wrn_40_1', 'wrn_40_2', 57 | 'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'ResNet50', 'MobileNetV2', 'ShuffleV1', 58 | 'ShuffleV2', 'ResNet34', 'wrn_16_4', 'wrn_40_4', 'wrn_16_10', 'ResNet10']) 59 | 60 | parser.add_argument('--path_t', type=str, default=None, help='teacher model') 61 | 62 | # distillation 63 | parser.add_argument('--distill', type=str, default='CID', choices=['CID']) 64 | parser.add_argument('--trial', type=str, default='1', help='trial id') 65 | 66 | parser.add_argument('-a', '--aa', type=float, default=1, help='weight for classification') 67 | parser.add_argument('-b', '--bb', type=float, default=None, help='weight balance for sample') 68 | parser.add_argument('-c', '--cc', type=float, default=None, help='weight balance for class') 69 | 70 | parser.add_argument('-NT', '--net_T', type=float, default=4, help='net Tempereture') 71 | 72 | parser.add_argument('-s', '--seed', type=int, default=1, help='seed') 73 | 74 | parser.add_argument('-u', '--cu', type=float, default=0, help='moving average cofficient') 75 | 76 | # last layer 77 | parser.add_argument('--hint_layer', default=-1, type=int, choices=[-1]) 78 | 79 | opt = parser.parse_args() 80 | 81 | # set different learning rate fro these models 82 | if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']: 83 | opt.learning_rate = 0.01 84 | 85 | 86 | # set the path according to the environment 87 | opt.model_path = './save/student_model' 88 | 89 | iterations = opt.lr_decay_epochs.split(',') 90 | opt.lr_decay_epochs = list([]) 91 | for it in iterations: 92 | opt.lr_decay_epochs.append(int(it)) 93 | 94 | opt.model_t = get_teacher_name(opt.path_t) 95 | 96 | opt.model_name = 'S:{}_{}_{}_a:{}_b:{}_c:_{}{}'.format(opt.model_s, opt.dataset, opt.distill, 97 | opt.aa, opt.bb, opt.cc, opt.trial) 98 | 99 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 100 | if not os.path.isdir(opt.save_folder): 101 | os.makedirs(opt.save_folder) 102 | 103 | return opt 104 | 105 | 106 | def get_teacher_name(model_path): 107 | """parse teacher name""" 108 | segments = model_path.split('/')[-2].split('_') 109 | if segments[0] != 'wrn': 110 | return segments[0] 111 | else: 112 | return segments[0] + '_' + segments[1] + '_' + segments[2] 113 | 114 | 115 | def load_teacher(model_path, n_cls): 116 | print('==> loading teacher model') 117 | model_t = get_teacher_name(model_path) 118 | model = model_dict[model_t](num_classes=n_cls) 119 | model.load_state_dict(torch.load(model_path)['model']) 120 | print('==> done') 121 | return model 122 | 123 | 124 | def main(): 125 | 126 | 127 | best_acc = 0 128 | 129 | opt = parse_option() 130 | 131 | print(opt) 132 | 133 | set_seed(opt.seed) 134 | 135 | 136 | # dataloader 137 | if opt.dataset == 'cifar100': 138 | 139 | train_loader, val_loader, n_data = get_cifar100_dataloaders(batch_size=opt.batch_size, 140 | num_workers=opt.num_workers, 141 | is_instance=True) 142 | n_cls = 100 143 | else: 144 | raise NotImplementedError(opt.dataset) 145 | 146 | # model 147 | model_t = load_teacher(opt.path_t, n_cls) 148 | model_s = model_dict[opt.model_s](num_classes=n_cls) 149 | 150 | data = torch.randn(2, 3, 32, 32) 151 | model_t.eval() 152 | model_s.eval() 153 | feat_t, _ = model_t(data, is_feat=True) 154 | feat_s, _= model_s(data, is_feat=True) 155 | 156 | module_list = nn.ModuleList([]) 157 | module_list.append(model_s) 158 | 159 | trainable_list = nn.ModuleList([]) 160 | trainable_list2 = nn.ModuleList([]) 161 | 162 | trainable_list.append(model_s) 163 | 164 | criterion_cls = nn.CrossEntropyLoss() 165 | criterion_div = KL(opt.net_T) 166 | 167 | if opt.distill == 'CID': 168 | 169 | criterion_kd = NORM_MSE() 170 | 171 | criterion_cc = Cosine() 172 | 173 | _, Cs_h = feat_s[opt.hint_layer].shape 174 | 175 | model_s_fc_new = Reg( Cs_h*2, n_cls) 176 | 177 | module_list.append(model_s_fc_new) 178 | 179 | trainable_list.append(model_s_fc_new) 180 | 181 | _, Ct_h = feat_t[opt.hint_layer].shape 182 | 183 | Reger_fea = Reg( Cs_h, Ct_h) 184 | 185 | module_list.append(Reger_fea) 186 | 187 | trainable_list2.append(Reger_fea) 188 | 189 | else: 190 | raise NotImplementedError(opt.distill) 191 | 192 | 193 | 194 | criterion_list = nn.ModuleList([]) 195 | criterion_list.append(criterion_cls) 196 | criterion_list.append(criterion_div) 197 | 198 | criterion_list.append(criterion_kd) 199 | criterion_list.append(criterion_cc) 200 | 201 | 202 | # optimizer 203 | optimizer = optim.SGD([{'params': trainable_list.parameters()}, {'params': trainable_list2.parameters(), 'weight_decay': 0.0}], 204 | lr=opt.learning_rate, 205 | momentum=opt.momentum, 206 | weight_decay=opt.weight_decay, 207 | nesterov=True) 208 | 209 | # append teacher after optimizer to avoid weight_decay 210 | module_list.append(model_t) 211 | 212 | if torch.cuda.is_available(): 213 | module_list.cuda() 214 | criterion_list.cuda() 215 | cudnn.benchmark = True 216 | 217 | # validate teacher accuracy 218 | teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt) 219 | print('teacher accuracy: ', teacher_acc) 220 | 221 | 222 | # routine 223 | for epoch in range(1, opt.epochs + 1): 224 | 225 | adjust_learning_rate(epoch, opt, optimizer) 226 | print("==> training...") 227 | 228 | time1 = time.time() 229 | 230 | if epoch <= opt.init_epochs: 231 | train_acc, train_loss, context = train_init(epoch, train_loader, module_list, criterion_list, optimizer, opt) 232 | context_old = context 233 | else: 234 | context_old = context 235 | train_acc, train_loss, context = train(epoch, train_loader, module_list, criterion_list, optimizer, opt, context) 236 | 237 | if train_loss!=train_loss: 238 | return 239 | 240 | time2 = time.time() 241 | print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) 242 | 243 | 244 | test_acc, tect_acc_top5 = validate_st(val_loader, model_s, criterion_cls, opt, context_old, model_s_fc_new) 245 | 246 | 247 | 248 | if test_acc > best_acc: 249 | best_acc = test_acc 250 | state = { 251 | 'epoch': epoch, 252 | 'model': model_s.state_dict(), 253 | 'model_s_fc_new': model_s_fc_new.state_dict(), 254 | 'context_old': context_old, 255 | 'best_acc': best_acc, 256 | } 257 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s)) 258 | print('saving the best model!') 259 | torch.save(state, save_file) 260 | 261 | # regular saving 262 | if epoch % opt.save_freq == 0: 263 | print('==> Saving...') 264 | state = { 265 | 'epoch': epoch, 266 | 'model': model_s.state_dict(), 267 | 'accuracy': test_acc, 268 | 'model_s_fc_new': model_s_fc_new.state_dict(), 269 | 'context_old': context_old 270 | } 271 | save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch)) 272 | torch.save(state, save_file) 273 | 274 | # This best accuracy is only for printing purpose. 275 | print('best accuracy:', best_acc.cpu().numpy()) 276 | 277 | # save model 278 | state = { 279 | 'opt': opt, 280 | 'model': model_s.state_dict(), 281 | 'model_s_fc_new': model_s_fc_new.state_dict(), 282 | 'context_old': context_old 283 | } 284 | save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s)) 285 | torch.save(state, save_file) 286 | 287 | 288 | if __name__ == '__main__': 289 | main() 290 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Paraphraser(nn.Module): 8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer""" 9 | def __init__(self, t_shape, k=0.5, use_bn=False): 10 | super(Paraphraser, self).__init__() 11 | in_channel = t_shape[1] 12 | out_channel = int(t_shape[1] * k) 13 | self.encoder = nn.Sequential( 14 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 15 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 16 | nn.LeakyReLU(0.1, inplace=True), 17 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 18 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 19 | nn.LeakyReLU(0.1, inplace=True), 20 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 21 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 22 | nn.LeakyReLU(0.1, inplace=True), 23 | ) 24 | self.decoder = nn.Sequential( 25 | nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1), 26 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 27 | nn.LeakyReLU(0.1, inplace=True), 28 | nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1), 29 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 30 | nn.LeakyReLU(0.1, inplace=True), 31 | nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1), 32 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 33 | nn.LeakyReLU(0.1, inplace=True), 34 | ) 35 | 36 | def forward(self, f_s, is_factor=False): 37 | factor = self.encoder(f_s) 38 | if is_factor: 39 | return factor 40 | rec = self.decoder(factor) 41 | return factor, rec 42 | 43 | 44 | class Translator(nn.Module): 45 | def __init__(self, s_shape, t_shape, k=0.5, use_bn=True): 46 | super(Translator, self).__init__() 47 | in_channel = s_shape[1] 48 | out_channel = int(t_shape[1] * k) 49 | self.encoder = nn.Sequential( 50 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 51 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 52 | nn.LeakyReLU(0.1, inplace=True), 53 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 54 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 55 | nn.LeakyReLU(0.1, inplace=True), 56 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 57 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | ) 60 | 61 | def forward(self, f_s): 62 | return self.encoder(f_s) 63 | 64 | 65 | class Connector(nn.Module): 66 | """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons""" 67 | def __init__(self, s_shapes, t_shapes): 68 | super(Connector, self).__init__() 69 | self.s_shapes = s_shapes 70 | self.t_shapes = t_shapes 71 | 72 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 73 | 74 | @staticmethod 75 | def _make_conenctors(s_shapes, t_shapes): 76 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 77 | connectors = [] 78 | for s, t in zip(s_shapes, t_shapes): 79 | if s[1] == t[1] and s[2] == t[2]: 80 | connectors.append(nn.Sequential()) 81 | else: 82 | connectors.append(ConvReg(s, t, use_relu=False)) 83 | return connectors 84 | 85 | def forward(self, g_s): 86 | out = [] 87 | for i in range(len(g_s)): 88 | out.append(self.connectors[i](g_s[i])) 89 | 90 | return out 91 | 92 | 93 | class ConnectorV2(nn.Module): 94 | """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)""" 95 | def __init__(self, s_shapes, t_shapes): 96 | super(ConnectorV2, self).__init__() 97 | self.s_shapes = s_shapes 98 | self.t_shapes = t_shapes 99 | 100 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 101 | 102 | def _make_conenctors(self, s_shapes, t_shapes): 103 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 104 | t_channels = [t[1] for t in t_shapes] 105 | s_channels = [s[1] for s in s_shapes] 106 | connectors = nn.ModuleList([self._build_feature_connector(t, s) 107 | for t, s in zip(t_channels, s_channels)]) 108 | return connectors 109 | 110 | @staticmethod 111 | def _build_feature_connector(t_channel, s_channel): 112 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 113 | nn.BatchNorm2d(t_channel)] 114 | for m in C: 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. / n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | return nn.Sequential(*C) 122 | 123 | def forward(self, g_s): 124 | out = [] 125 | for i in range(len(g_s)): 126 | out.append(self.connectors[i](g_s[i])) 127 | 128 | return out 129 | 130 | 131 | class ConvReg(nn.Module): 132 | """Convolutional regression for FitNet""" 133 | def __init__(self, s_shape, t_shape, use_relu=True): 134 | super(ConvReg, self).__init__() 135 | self.use_relu = use_relu 136 | s_N, s_C, s_H, s_W = s_shape 137 | t_N, t_C, t_H, t_W = t_shape 138 | if s_H == 2 * t_H: 139 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) 140 | elif s_H * 2 == t_H: 141 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 142 | elif s_H >= t_H: 143 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 144 | else: 145 | raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H)) 146 | self.bn = nn.BatchNorm2d(t_C) 147 | self.relu = nn.ReLU(inplace=True) 148 | 149 | def forward(self, x): 150 | x = self.conv(x) 151 | if self.use_relu: 152 | return self.relu(self.bn(x)) 153 | else: 154 | return self.bn(x) 155 | 156 | 157 | class Regress(nn.Module): 158 | """Simple Linear Regression for hints""" 159 | def __init__(self, dim_in=1024, dim_out=1024): 160 | super(Regress, self).__init__() 161 | self.linear = nn.Linear(dim_in, dim_out) 162 | self.relu = nn.ReLU(inplace=True) 163 | 164 | def forward(self, x): 165 | x = x.view(x.shape[0], -1) 166 | x = self.linear(x) 167 | x = self.relu(x) 168 | return x 169 | 170 | class Reg(nn.Module): 171 | """Simple Linear Regression for hints""" 172 | def __init__(self, dim_in=1024, dim_out=1024): 173 | super(Reg, self).__init__() 174 | self.linear = nn.Linear(dim_in, dim_out) 175 | 176 | def forward(self, x): 177 | x = self.linear(x) 178 | return x 179 | 180 | 181 | class Embed(nn.Module): 182 | """Embedding module""" 183 | def __init__(self, dim_in=1024, dim_out=128): 184 | super(Embed, self).__init__() 185 | self.linear = nn.Linear(dim_in, dim_out) 186 | self.l2norm = Normalize(2) 187 | 188 | def forward(self, x): 189 | x = x.view(x.shape[0], -1) 190 | x = self.linear(x) 191 | x = self.l2norm(x) 192 | return x 193 | 194 | 195 | class LinearEmbed(nn.Module): 196 | """Linear Embedding""" 197 | def __init__(self, dim_in=1024, dim_out=128): 198 | super(LinearEmbed, self).__init__() 199 | self.linear = nn.Linear(dim_in, dim_out) 200 | 201 | def forward(self, x): 202 | x = x.view(x.shape[0], -1) 203 | x = self.linear(x) 204 | return x 205 | 206 | 207 | class MLPEmbed(nn.Module): 208 | """non-linear embed by MLP""" 209 | def __init__(self, dim_in=1024, dim_out=128): 210 | super(MLPEmbed, self).__init__() 211 | self.linear1 = nn.Linear(dim_in, 2 * dim_out) 212 | self.relu = nn.ReLU(inplace=True) 213 | self.linear2 = nn.Linear(2 * dim_out, dim_out) 214 | self.l2norm = Normalize(2) 215 | 216 | def forward(self, x): 217 | x = x.view(x.shape[0], -1) 218 | x = self.relu(self.linear1(x)) 219 | x = self.l2norm(self.linear2(x)) 220 | return x 221 | 222 | 223 | class Normalize(nn.Module): 224 | """normalization layer""" 225 | def __init__(self, power=2): 226 | super(Normalize, self).__init__() 227 | self.power = power 228 | 229 | def forward(self, x): 230 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 231 | out = x.div(norm) 232 | return out 233 | 234 | 235 | class Flatten(nn.Module): 236 | """flatten module""" 237 | def __init__(self): 238 | super(Flatten, self).__init__() 239 | 240 | def forward(self, feat): 241 | return feat.view(feat.size(0), -1) 242 | 243 | 244 | class PoolEmbed(nn.Module): 245 | """pool and embed""" 246 | def __init__(self, layer=0, dim_out=128, pool_type='avg'): 247 | super().__init__() 248 | if layer == 0: 249 | pool_size = 8 250 | nChannels = 16 251 | elif layer == 1: 252 | pool_size = 8 253 | nChannels = 16 254 | elif layer == 2: 255 | pool_size = 6 256 | nChannels = 32 257 | elif layer == 3: 258 | pool_size = 4 259 | nChannels = 64 260 | elif layer == 4: 261 | pool_size = 1 262 | nChannels = 64 263 | else: 264 | raise NotImplementedError('layer not supported: {}'.format(layer)) 265 | 266 | self.embed = nn.Sequential() 267 | if layer <= 3: 268 | if pool_type == 'max': 269 | self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) 270 | elif pool_type == 'avg': 271 | self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) 272 | 273 | self.embed.add_module('Flatten', Flatten()) 274 | self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out)) 275 | self.embed.add_module('Normalize', Normalize(2)) 276 | 277 | def forward(self, x): 278 | return self.embed(x) 279 | 280 | 281 | if __name__ == '__main__': 282 | import torch 283 | 284 | g_s = [ 285 | torch.randn(2, 16, 16, 16), 286 | torch.randn(2, 32, 8, 8), 287 | torch.randn(2, 64, 4, 4), 288 | ] 289 | g_t = [ 290 | torch.randn(2, 32, 16, 16), 291 | torch.randn(2, 64, 8, 8), 292 | torch.randn(2, 128, 4, 4), 293 | ] 294 | s_shapes = [s.shape for s in g_s] 295 | t_shapes = [t.shape for t in g_t] 296 | 297 | net = ConnectorV2(s_shapes, t_shapes) 298 | out = net(g_s) 299 | for f in out: 300 | print(f.shape) 301 | -------------------------------------------------------------------------------- /helper/loops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | 3 | import sys 4 | import time 5 | import torch 6 | 7 | import torch.nn as nn 8 | 9 | from .util import AverageMeter, accuracy 10 | from helper.util import cluster 11 | 12 | 13 | 14 | 15 | def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt): 16 | 17 | # set modules as train() 18 | for module in module_list: 19 | module.train() 20 | 21 | # set teacher as eval() 22 | module_list[-1].eval() 23 | 24 | #criterion_cls = criterion_list[0] 25 | criterion_kl = criterion_list[1] 26 | criterion_mse = criterion_list[2] 27 | criterion_sp = criterion_list[3] 28 | 29 | softmax = nn.Softmax(dim=1).cuda() 30 | 31 | 32 | model_s = module_list[0] 33 | model_t = module_list[-1] 34 | 35 | try: 36 | context_new = torch.zeros( model_s.fc.weight.shape, dtype=torch.float32).cuda() 37 | current_num = torch.zeros(model_s.fc.weight.shape[0], dtype=torch.float32) 38 | class_num = model_s.fc.weight.shape[0] 39 | except: 40 | try: 41 | context_new = torch.zeros( model_s.linear.weight.shape, dtype=torch.float32).cuda() 42 | current_num = torch.zeros(model_s.linear.weight.shape[0], dtype=torch.float32) 43 | class_num = model_s.linear.weight.shape[0] 44 | except: 45 | context_new = torch.zeros( model_s.classifier.weight.shape, dtype=torch.float32).cuda() 46 | current_num = torch.zeros(model_s.classifier.weight.shape[0], dtype=torch.float32) 47 | class_num = model_s.classifier.weight.shape[0] 48 | 49 | batch_time = AverageMeter() 50 | data_time = AverageMeter() 51 | losses = AverageMeter() 52 | top1 = AverageMeter() 53 | top5 = AverageMeter() 54 | 55 | end = time.time() 56 | 57 | for idx, data in enumerate(train_loader): 58 | 59 | input, target, index = data 60 | 61 | data_time.update(time.time() - end) 62 | 63 | input = input.float() 64 | if torch.cuda.is_available(): 65 | input = input.cuda() 66 | target = target.cuda() 67 | 68 | # ===================forward===================== 69 | preact = False 70 | 71 | feat_s, logit_s = model_s(input, is_feat=True, preact=preact) 72 | 73 | with torch.no_grad(): 74 | feat_t, logit_t = model_t(input, is_feat=True, preact=preact) 75 | feat_t = [f.detach() for f in feat_t] 76 | 77 | 78 | 79 | if epoch==opt.init_epochs: 80 | 81 | fea_s = feat_s[opt.hint_layer].detach() 82 | 83 | soft_t = softmax(logit_t/opt.net_T) 84 | 85 | for i in range( len(target) ): 86 | context_new[target[i]] = context_new[target[i]]*( current_num[target[i]]/(current_num[target[i]]+soft_t[i][target[i]]) )+ fea_s[i]*(soft_t[i][target[i]]/(current_num[target[i]]+soft_t[i][target[i]]) ) 87 | current_num[target[i]]+= soft_t[i][target[i]] 88 | 89 | # loss 90 | loss_kl = criterion_kl(logit_s, logit_t) 91 | 92 | fea_reg = module_list[2] 93 | f_s = fea_reg(feat_s[opt.hint_layer]) 94 | f_t = feat_t[opt.hint_layer] 95 | 96 | loss_sample = criterion_mse(f_s, f_t) 97 | 98 | list_s, list_t = cluster(feat_s[opt.hint_layer], f_t, target, class_num) 99 | 100 | involve_class = 0 101 | 102 | loss_class=0.0 103 | 104 | for k in range( len(list_s) ): 105 | 106 | cur_len = len( list_s[k] ) 107 | 108 | if cur_len>=2: 109 | cur_f_s = torch.stack(list_s[k]) 110 | cur_f_t = torch.stack(list_t[k]) 111 | 112 | loss_class+= criterion_sp(cur_f_s, cur_f_t) 113 | 114 | involve_class += 1 115 | 116 | 117 | if involve_class==0: 118 | loss_class = 0.0 119 | else: 120 | loss_class = loss_class/involve_class 121 | 122 | 123 | loss = opt.aa*loss_kl + opt.bb * loss_sample + opt.cc * loss_class 124 | 125 | acc1, acc5 = accuracy(logit_s, target, topk=(1, 5)) 126 | losses.update(loss.item(), input.size(0)) 127 | top1.update(acc1[0], input.size(0)) 128 | top5.update(acc5[0], input.size(0)) 129 | 130 | # ===================backward===================== 131 | optimizer.zero_grad() 132 | loss.backward() 133 | optimizer.step() 134 | 135 | # ===================meters===================== 136 | batch_time.update(time.time() - end) 137 | end = time.time() 138 | 139 | # print info 140 | if idx % opt.print_freq == 0: 141 | print('Epoch: [{0}][{1}/{2}]\t' 142 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 143 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 144 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 145 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 146 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 147 | epoch, idx, len(train_loader), batch_time=batch_time, 148 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 149 | sys.stdout.flush() 150 | 151 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 152 | .format(top1=top1, top5=top5)) 153 | 154 | return top1.avg, losses.avg, context_new 155 | 156 | 157 | 158 | def train_distill_context(epoch, train_loader, module_list, criterion_list, optimizer, opt, context): 159 | """One epoch distillation""" 160 | # set modules as train() 161 | for module in module_list: 162 | module.train() 163 | 164 | # set teacher as eval() 165 | module_list[-1].eval() 166 | 167 | 168 | context_new = torch.zeros(context.shape, dtype=torch.float32).cuda() 169 | 170 | current_num = torch.zeros(context.shape[0], dtype=torch.float32) 171 | 172 | 173 | criterion_cls = criterion_list[0] 174 | criterion_kl = criterion_list[1] 175 | criterion_mse = criterion_list[2] 176 | criterion_sp = criterion_list[3] 177 | 178 | softmax = nn.Softmax(dim=1).cuda() 179 | 180 | 181 | model_s = module_list[0] 182 | model_s_fc_new = module_list[1] 183 | model_t = module_list[-1] 184 | 185 | batch_time = AverageMeter() 186 | data_time = AverageMeter() 187 | losses = AverageMeter() 188 | top1 = AverageMeter() 189 | top5 = AverageMeter() 190 | 191 | end = time.time() 192 | 193 | for idx, data in enumerate(train_loader): 194 | 195 | input, target, index = data 196 | 197 | data_time.update(time.time() - end) 198 | 199 | input = input.float() 200 | if torch.cuda.is_available(): 201 | input = input.cuda() 202 | target = target.cuda() 203 | 204 | # ===================forward===================== 205 | preact = False 206 | 207 | feat_s, logit_s = model_s(input, is_feat=True, preact=preact) 208 | 209 | 210 | with torch.no_grad(): 211 | feat_t, logit_t = model_t(input, is_feat=True, preact=preact) 212 | feat_t = [f.detach() for f in feat_t] 213 | 214 | fea_s = feat_s[opt.hint_layer].detach() 215 | 216 | soft_t = softmax(logit_t/opt.net_T) 217 | 218 | for i in range( len(target) ): 219 | context_new[target[i]] = context_new[target[i]]*( current_num[target[i]]/(current_num[target[i]]+soft_t[i][target[i]]) ) + fea_s[i]*(soft_t[i][target[i]]/(current_num[target[i]]+soft_t[i][target[i]]) ) 220 | current_num[target[i]]+=soft_t[i][target[i]] 221 | 222 | 223 | p = softmax(logit_s.detach()/opt.net_T) 224 | 225 | sam_contxt = torch.mm(p, context) 226 | 227 | f_new = torch.cat((feat_s[opt.hint_layer], sam_contxt),1) 228 | 229 | logit_s_new = model_s_fc_new(f_new) 230 | 231 | 232 | loss_cls = criterion_cls(logit_s_new, target) 233 | 234 | loss_kl = criterion_kl(logit_s, logit_t) 235 | 236 | 237 | class_num = model_s_fc_new.linear.weight.shape[0] 238 | 239 | fea_reg = module_list[2] 240 | f_s = fea_reg(feat_s[opt.hint_layer]) 241 | f_t = feat_t[opt.hint_layer] 242 | 243 | loss_sample = criterion_mse(f_s, f_t) 244 | 245 | list_s, list_t = cluster(feat_s[opt.hint_layer], f_t, target, class_num) 246 | 247 | involve_class = 0 248 | 249 | loss_class=0.0 250 | 251 | for k in range( len(list_s) ): 252 | 253 | cur_len = len( list_s[k] ) 254 | 255 | if cur_len>=2: 256 | 257 | cur_f_s = torch.stack(list_s[k]) 258 | cur_f_t = torch.stack(list_t[k]) 259 | 260 | loss_class+= criterion_sp(cur_f_s, cur_f_t) 261 | 262 | involve_class += 1 263 | 264 | 265 | if involve_class==0: 266 | loss_class = 0.0 267 | else: 268 | loss_class = loss_class/involve_class 269 | 270 | 271 | loss = opt.aa*(loss_cls + loss_kl) + opt.bb * loss_sample + opt.cc * loss_class 272 | 273 | acc1, acc5 = accuracy(logit_s_new, target, topk=(1, 5)) 274 | losses.update(loss.item(), input.size(0)) 275 | top1.update(acc1[0], input.size(0)) 276 | top5.update(acc5[0], input.size(0)) 277 | 278 | # ===================backward===================== 279 | optimizer.zero_grad() 280 | loss.backward() 281 | optimizer.step() 282 | 283 | # ===================meters===================== 284 | batch_time.update(time.time() - end) 285 | end = time.time() 286 | 287 | # print info 288 | if idx % opt.print_freq == 0: 289 | print('Epoch: [{0}][{1}/{2}]\t' 290 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 291 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 292 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 293 | 'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t' 294 | 'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 295 | epoch, idx, len(train_loader), batch_time=batch_time, 296 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 297 | sys.stdout.flush() 298 | 299 | print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 300 | .format(top1=top1, top5=top5)) 301 | 302 | 303 | context_new = opt.cu*context + (1-opt.cu)*context_new 304 | 305 | return top1.avg, losses.avg, context_new 306 | 307 | 308 | def validate(val_loader, model, criterion, opt): 309 | """validation""" 310 | batch_time = AverageMeter() 311 | losses = AverageMeter() 312 | top1 = AverageMeter() 313 | top5 = AverageMeter() 314 | 315 | # switch to evaluate mode 316 | model.eval() 317 | 318 | with torch.no_grad(): 319 | end = time.time() 320 | for idx, (input, target) in enumerate(val_loader): 321 | 322 | input = input.float() 323 | if torch.cuda.is_available(): 324 | input = input.cuda() 325 | target = target.cuda() 326 | 327 | # compute output 328 | output = model(input) 329 | loss = criterion(output, target) 330 | 331 | # measure accuracy and record loss 332 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 333 | losses.update(loss.item(), input.size(0)) 334 | top1.update(acc1[0], input.size(0)) 335 | top5.update(acc5[0], input.size(0)) 336 | 337 | # measure elapsed time 338 | batch_time.update(time.time() - end) 339 | end = time.time() 340 | 341 | 342 | print(' * test Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' 343 | .format(top1=top1, top5=top5)) 344 | 345 | return top1.avg, top5.avg, losses.avg 346 | 347 | 348 | def validate_st(val_loader, model, criterion, opt, context, model_fc_new): 349 | """validation""" 350 | batch_time = AverageMeter() 351 | 352 | top1_new = AverageMeter() 353 | top5_new = AverageMeter() 354 | 355 | softmax = nn.Softmax(dim=1).cuda() 356 | 357 | # switch to evaluate mode 358 | model.eval() 359 | model_fc_new.eval() 360 | 361 | with torch.no_grad(): 362 | end = time.time() 363 | for idx, (input, target) in enumerate(val_loader): 364 | 365 | input = input.float() 366 | if torch.cuda.is_available(): 367 | input = input.cuda() 368 | target = target.cuda() 369 | 370 | 371 | # compute output 372 | feat, output = model(input, is_feat=True, preact=False) 373 | 374 | p = softmax(output/opt.net_T) 375 | 376 | sam_contxt = torch.mm(p, context) 377 | 378 | f_new = torch.cat((feat[opt.hint_layer], sam_contxt),1) 379 | 380 | output_new = model_fc_new(f_new) 381 | 382 | 383 | acc1_new, acc5_new = accuracy(output_new, target, topk=(1, 5)) 384 | top1_new.update(acc1_new[0], input.size(0)) 385 | top5_new.update(acc5_new[0], input.size(0)) 386 | 387 | # measure elapsed time 388 | batch_time.update(time.time() - end) 389 | end = time.time() 390 | 391 | print(' * Test Acc@1 {top1_new.avg:.3f} Acc@5 {top5_new.avg:.3f}' 392 | .format(top1_new=top1_new, top5_new=top5_new)) 393 | 394 | return top1_new.avg, top5_new.avg 395 | --------------------------------------------------------------------------------