├── .DS_Store ├── README.md ├── crd ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── criterion.cpython-36.pyc │ ├── criterion.cpython-39.pyc │ ├── memory.cpython-36.pyc │ └── memory.cpython-39.pyc ├── criterion.py └── memory.py ├── dataset ├── __pycache__ │ ├── cifar100.cpython-36.pyc │ ├── cifar100.cpython-39.pyc │ ├── imagenet.cpython-36.pyc │ ├── imagenet.cpython-39.pyc │ ├── imagenet_dali.cpython-36.pyc │ └── imagenet_dali.cpython-39.pyc ├── base.py ├── cifar100.py ├── imagenet.py └── imagenet_dali.py ├── distiller_zoo ├── AT.py ├── FitNet.py ├── KD.py ├── SP.py ├── SemCKD.py ├── VID.py ├── __init__.py └── __pycache__ │ ├── AT.cpython-36.pyc │ ├── AT.cpython-39.pyc │ ├── FitNet.cpython-36.pyc │ ├── FitNet.cpython-39.pyc │ ├── KD.cpython-36.pyc │ ├── KD.cpython-39.pyc │ ├── SP.cpython-36.pyc │ ├── SP.cpython-39.pyc │ ├── SemCKD.cpython-36.pyc │ ├── SemCKD.cpython-39.pyc │ ├── VID.cpython-36.pyc │ ├── VID.cpython-39.pyc │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-39.pyc ├── helper ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── loops.cpython-36.pyc │ ├── loops.cpython-39.pyc │ ├── util.cpython-36.pyc │ └── util.cpython-39.pyc ├── loops.py └── util.py ├── images ├── .DS_Store ├── SimKD_result.png └── cifar100_result.png ├── models ├── ShuffleNetv1.py ├── ShuffleNetv2.py ├── __init__.py ├── __pycache__ │ ├── ShuffleNetv1.cpython-36.pyc │ ├── ShuffleNetv1.cpython-39.pyc │ ├── ShuffleNetv2.cpython-36.pyc │ ├── ShuffleNetv2.cpython-39.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-39.pyc │ ├── mobilenetv2.cpython-36.pyc │ ├── mobilenetv2.cpython-39.pyc │ ├── mobilenetv2_imagenet.cpython-36.pyc │ ├── mobilenetv2_imagenet.cpython-39.pyc │ ├── resnet.cpython-36.pyc │ ├── resnet.cpython-39.pyc │ ├── resnet_imagenet.cpython-36.pyc │ ├── resnet_imagenet.cpython-39.pyc │ ├── shuffleNetv2_imagenet.cpython-36.pyc │ ├── shuffleNetv2_imagenet.cpython-39.pyc │ ├── util.cpython-36.pyc │ ├── util.cpython-39.pyc │ ├── vgg.cpython-36.pyc │ └── vgg.cpython-39.pyc ├── mobilenetv2.py ├── mobilenetv2_imagenet.py ├── resnet.py ├── resnet_imagenet.py ├── shuffleNetv2_imagenet.py ├── util.py └── vgg.py ├── scripts ├── run_distill.sh └── run_vanilla.sh ├── train_student.py └── train_teacher.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimKD 2 | 3 | Knowledge Distillation with the Reused Teacher Classifier (CVPR-2022) https://arxiv.org/abs/2203.14001 4 | 5 | # Toolbox for KD research 6 | 7 | This repository aims to provide a compact and easy-to-use implementation of several representative knowledge distillation approaches on standard image classification tasks (e.g., CIFAR100, ImageNet). 8 | 9 | - Generally, these KD approaches include a **classification loss**, a **logit-level distillation loss**, and an additional **feature distillation loss**. 
For fair comparison and ease of tuning, *we fix the hyper-parameters for the first two loss terms as **one** throughout all experiments.* (`--cls 1 --div 1`)
10 | 
11 | - The following approaches are currently supported by this toolbox, covering vanilla KD, feature-map and feature-embedding distillation, and instance-level and pairwise-level distillation:
12 |   - [x] [Vanilla KD](https://arxiv.org/abs/1503.02531), [FitNet](https://arxiv.org/abs/1412.6550) [ICLR-2015], [AT](https://arxiv.org/abs/1612.03928) [ICLR-2017], [SP](https://arxiv.org/abs/1907.09682) [ICCV-2019], [VID](https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf) [CVPR-2019]
13 |   - [x] [CRD](https://arxiv.org/abs/1910.10699) [ICLR-2020], [SRRL](https://openreview.net/forum?id=ZzwDy_wiWv) [ICLR-2021], [SemCKD](https://arxiv.org/abs/2012.03236) [AAAI-2021]
14 |   - [ ] KR [CVPR-2021]
15 |   - [x] [SimKD](https://arxiv.org/abs/2203.14001) [CVPR-2022]
16 | 
17 | - This toolbox is built on an [open-source benchmark](https://github.com/HobbitLong/RepDistiller) and our [previous repository](https://github.com/DefangChen/SemCKD). Implementations of more KD approaches can be found there.
18 | 
19 | - Computing Infrastructure:
20 |   - We use one NVIDIA GeForce RTX 2080Ti GPU (PyTorch 1.0) for the CIFAR-100 experiments and four NVIDIA A40 GPUs (PyTorch 1.10) for the ImageNet experiments.
21 |   - For ImageNet, we use [DALI](https://github.com/NVIDIA/DALI) for data loading and pre-processing.
22 | 
23 | - The code has recently been reorganized and has not yet been thoroughly re-tested. If you have any questions, please do not hesitate to contact us.
24 | 
25 | - Please place the CIFAR-100 and ImageNet datasets in `../data/`.
26 | 
27 | ## Get the pretrained teacher models
28 | 
29 | ```bash
30 | # CIFAR-100
31 | python train_teacher.py --batch_size 64 --epochs 240 --dataset cifar100 --model resnet32x4 --learning_rate 0.05 --lr_decay_epochs 150,180,210 --weight_decay 5e-4 --trial 0 --gpu_id 0
32 | 
33 | # ImageNet
34 | python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0
35 | ```
36 | 
37 | The pretrained teacher models used in our paper are provided at this link: [[GoogleDrive]](https://drive.google.com/drive/folders/1j7b8TmftKIRC7ChUwAqVWPIocSiacvP4?usp=sharing).
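After downloading, a quick way to verify a checkpoint before training (a minimal sketch; the `'model'` key is an assumption about how `train_teacher.py` saves checkpoints, so inspect `ckpt.keys()` if it differs):

```python
# Hypothetical sanity check for a downloaded teacher checkpoint.
import torch

ckpt = torch.load('./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth',
                  map_location='cpu')
print(ckpt.keys())          # see what the checkpoint actually stores
state_dict = ckpt['model']  # assumed key holding the model weights
print(len(state_dict), 'parameter tensors loaded')
```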
38 | 39 | ## Train the student models with various KD approaches 40 | 41 | ```bash 42 | # CIFAR-100 43 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0 44 | 45 | # ImageNet 46 | python train_student.py --path-t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23444 --multiprocessing-distributed --dali gpu --trial 0 47 | ``` 48 | More scripts are provided in `./scripts` 49 | 50 | ## Some results on CIFAR-100 51 | 52 | | | ResNet-8x4 | VGG-8 | ShuffleNetV2x1.5 | 53 | | ----- | ---- | ---- | ---- | 54 | | **Student**| 73.09 | 70.46 | 74.15 | 55 | | KD | 74.42 | 72.73 | 76.82 | 56 | | FitNet | 74.32 | 72.91 | 77.12 | 57 | | AT | 75.07 | 71.90 | 77.51 | 58 | | SP | 74.29 | 73.12 | 77.18 | 59 | | VID | 74.55 | 73.19 | 77.11 | 60 | | CRD | 75.59 | 73.54 | 77.66 | 61 | | SRRL | 75.39 | 73.23 | 77.55 | 62 | | SemCKD | 76.23 | 75.27 | 79.13 | 63 | | SimKD (f=8) | **76.73** | 74.74 | 78.96| 64 | | SimKD (f=4) | **77.88** | **75.62** | **79.48**| 65 | | SimKD (f=2) | **78.08** | **75.76** | **79.54** | 66 | | **Teacher (ResNet-32x4)** | 79.42 | 79.42 | 79.42 | 67 | 68 | 69 | ![result](./images/SimKD_result.png) 70 |
(Left) The cross-entropy loss between model predictions and test labels.
71 | (Right) The top-1 test accuracy (%) (Student: ResNet-8x4, Teacher: ResNet-32x4).
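As a reference for the `-c/--cls`, `-d/--div`, and `-b/--beta` flags used in the training commands above, `helper/loops.py` (`train_distill`) combines the three loss terms as in the excerpt below. The SimKD commands use `-c 0 -d 0 -b 1`: since SimKD reuses the teacher classifier, it is trained with the feature-matching loss alone.

```python
# Excerpt from helper/loops.py (train_distill):
#   loss_cls: cross-entropy with the ground-truth labels
#   loss_div: KL divergence from the teacher logits (vanilla KD)
#   loss_kd:  the method-specific (feature) distillation term
loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd
```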
72 | 73 | 74 | ## Citation 75 | If you find this repository useful, please consider citing the following paper: 76 | 77 | 78 | ``` 79 | @inproceedings{chen2022simkd, 80 | title={Knowledge Distillation with the Reused Teacher Classifier}, 81 | author={Chen, Defang and Mei, Jian-Ping and Zhang, Hailin and Wang, Can and Feng, Yan and Chen, Chun}, 82 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 83 | pages={11933--11942}, 84 | year={2022} 85 | } 86 | ``` 87 | ``` 88 | @inproceedings{chen2021cross, 89 | author = {Defang Chen and Jian{-}Ping Mei and Yuan Zhang and Can Wang and Zhe Wang and Yan Feng and Chun Chen}, 90 | title = {Cross-Layer Distillation with Semantic Calibration}, 91 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence}, 92 | pages = {7028--7036}, 93 | year = {2021}, 94 | } 95 | ``` 96 | 97 | -------------------------------------------------------------------------------- /crd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__init__.py -------------------------------------------------------------------------------- /crd/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /crd/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /crd/__pycache__/criterion.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/criterion.cpython-36.pyc -------------------------------------------------------------------------------- /crd/__pycache__/criterion.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/criterion.cpython-39.pyc -------------------------------------------------------------------------------- /crd/__pycache__/memory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/memory.cpython-36.pyc -------------------------------------------------------------------------------- /crd/__pycache__/memory.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/memory.cpython-39.pyc -------------------------------------------------------------------------------- /crd/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .memory import ContrastMemory 4 | 5 | eps = 1e-7 6 | 7 | 8 | class CRDLoss(nn.Module): 9 | """CRD Loss function 10 | includes two 
symmetric parts:
11 |     (a) using the teacher as anchor, choose positives and negatives from the student side
12 |     (b) using the student as anchor, choose positives and negatives from the teacher side
13 |     Args:
14 |         opt.s_dim: the dimension of the student's feature
15 |         opt.t_dim: the dimension of the teacher's feature
16 |         opt.feat_dim: the dimension of the projection space
17 |         opt.nce_k: number of negatives paired with each positive
18 |         opt.nce_t: the temperature
19 |         opt.nce_m: the momentum for updating the memory buffer
20 |         opt.n_data: the number of samples in the training set, therefore the memory buffer is of size opt.n_data x opt.feat_dim
21 |     """
22 |     def __init__(self, opt):
23 |         super(CRDLoss, self).__init__()
24 |         self.embed_s = Embed(opt.s_dim, opt.feat_dim)
25 |         self.embed_t = Embed(opt.t_dim, opt.feat_dim)
26 |         self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
27 |         self.criterion_t = ContrastLoss(opt.n_data)
28 |         self.criterion_s = ContrastLoss(opt.n_data)
29 | 
30 |     def forward(self, f_s, f_t, idx, contrast_idx=None):
31 |         """
32 |         Args:
33 |             f_s: the feature of the student network, size [batch_size, s_dim]
34 |             f_t: the feature of the teacher network, size [batch_size, t_dim]
35 |             idx: the indices of these positive samples in the dataset, size [batch_size]
36 |             contrast_idx: the indices of negative samples, size [batch_size, nce_k]
37 |         Returns:
38 |             The contrastive loss
39 |         """
40 |         f_s = self.embed_s(f_s)
41 |         f_t = self.embed_t(f_t)
42 |         out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
43 |         s_loss = self.criterion_s(out_s)
44 |         t_loss = self.criterion_t(out_t)
45 |         loss = s_loss + t_loss
46 |         return loss
47 | 
48 | 
49 | class ContrastLoss(nn.Module):
50 |     """
51 |     contrastive loss, corresponding to Eq (18)
52 |     """
53 |     def __init__(self, n_data):
54 |         super(ContrastLoss, self).__init__()
55 |         self.n_data = n_data
56 | 
57 |     def forward(self, x):
58 |         bsz = x.shape[0]
59 |         m = x.size(1) - 1
60 | 
61 |         # noise distribution
62 |         Pn = 1 / float(self.n_data)
63 | 
64 |         # loss for the positive pair
65 |         P_pos = x.select(1, 0)
66 |         log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
67 | 
68 |         # loss for the K negative pairs
69 |         P_neg = x.narrow(1, 1, m)
70 |         log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
71 | 
72 |         loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
73 | 
74 |         return loss
75 | 
76 | 
77 | class Embed(nn.Module):
78 |     """Embedding module"""
79 |     def __init__(self, dim_in=1024, dim_out=128):
80 |         super(Embed, self).__init__()
81 |         self.linear = nn.Linear(dim_in, dim_out)
82 |         self.l2norm = Normalize(2)
83 | 
84 |     def forward(self, x):
85 |         x = x.view(x.shape[0], -1)
86 |         x = self.linear(x)
87 |         x = self.l2norm(x)
88 |         return x
89 | 
90 | 
91 | class Normalize(nn.Module):
92 |     """normalization layer"""
93 |     def __init__(self, power=2):
94 |         super(Normalize, self).__init__()
95 |         self.power = power
96 | 
97 |     def forward(self, x):
98 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
99 |         out = x.div(norm)
100 |         return out
101 | 
--------------------------------------------------------------------------------
/crd/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 | 
5 | 
6 | class ContrastMemory(nn.Module):
7 |     """
8 |     memory buffer that supplies a large amount of negative samples.
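    Two memory banks are kept (`memory_v1` and `memory_v2`, one per view), each of
    size outputSize x inputSize (i.e., n_data x feat_dim); their entries are updated
    with momentum and re-normalized after every forward pass.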
9 | """ 10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5): 11 | super(ContrastMemory, self).__init__() 12 | self.nLem = outputSize 13 | self.unigrams = torch.ones(self.nLem) 14 | self.multinomial = AliasMethod(self.unigrams) 15 | self.multinomial.cuda() 16 | self.K = K 17 | 18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum])) 19 | stdv = 1. / math.sqrt(inputSize / 3) 20 | self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 21 | self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv)) 22 | 23 | def forward(self, v1, v2, y, idx=None): 24 | K = int(self.params[0].item()) 25 | T = self.params[1].item() 26 | Z_v1 = self.params[2].item() 27 | Z_v2 = self.params[3].item() 28 | 29 | momentum = self.params[4].item() 30 | batchSize = v1.size(0) 31 | outputSize = self.memory_v1.size(0) 32 | inputSize = self.memory_v1.size(1) 33 | 34 | # original score computation 35 | if idx is None: 36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1) 37 | idx.select(1, 0).copy_(y.data) 38 | # sample 39 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach() 40 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize) 41 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1)) 42 | out_v2 = torch.exp(torch.div(out_v2, T)) 43 | # sample 44 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach() 45 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize) 46 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1)) 47 | out_v1 = torch.exp(torch.div(out_v1, T)) 48 | 49 | # set Z if haven't been set yet 50 | if Z_v1 < 0: 51 | self.params[2] = out_v1.mean() * outputSize 52 | Z_v1 = self.params[2].clone().detach().item() 53 | print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1)) 54 | if Z_v2 < 0: 55 | self.params[3] = out_v2.mean() * outputSize 56 | Z_v2 = self.params[3].clone().detach().item() 57 | print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2)) 58 | 59 | # compute out_v1, out_v2 60 | out_v1 = torch.div(out_v1, Z_v1).contiguous() 61 | out_v2 = torch.div(out_v2, Z_v2).contiguous() 62 | 63 | # update memory 64 | with torch.no_grad(): 65 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1)) 66 | l_pos.mul_(momentum) 67 | l_pos.add_(torch.mul(v1, 1 - momentum)) 68 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5) 69 | updated_v1 = l_pos.div(l_norm) 70 | self.memory_v1.index_copy_(0, y, updated_v1) 71 | 72 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1)) 73 | ab_pos.mul_(momentum) 74 | ab_pos.add_(torch.mul(v2, 1 - momentum)) 75 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5) 76 | updated_v2 = ab_pos.div(ab_norm) 77 | self.memory_v2.index_copy_(0, y, updated_v2) 78 | 79 | return out_v1, out_v2 80 | 81 | 82 | class AliasMethod(object): 83 | """ 84 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/ 85 | """ 86 | def __init__(self, probs): 87 | 88 | if probs.sum() > 1: 89 | probs.div_(probs.sum()) 90 | K = len(probs) 91 | self.prob = torch.zeros(K) 92 | self.alias = torch.LongTensor([0]*K) 93 | 94 | # Sort the data into the outcomes with probabilities 95 | # that are larger and smaller than 1/K. 
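        # (Alias method: after this O(K) setup, every table slot mixes at most two
        # outcomes, so draw() can sample from the multinomial in O(1) per sample
        # instead of O(K).)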
96 | smaller = [] 97 | larger = [] 98 | for kk, prob in enumerate(probs): 99 | self.prob[kk] = K*prob 100 | if self.prob[kk] < 1.0: 101 | smaller.append(kk) 102 | else: 103 | larger.append(kk) 104 | 105 | # Loop though and create little binary mixtures that 106 | # appropriately allocate the larger outcomes over the 107 | # overall uniform mixture. 108 | while len(smaller) > 0 and len(larger) > 0: 109 | small = smaller.pop() 110 | large = larger.pop() 111 | 112 | self.alias[small] = large 113 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small] 114 | 115 | if self.prob[large] < 1.0: 116 | smaller.append(large) 117 | else: 118 | larger.append(large) 119 | 120 | for last_one in smaller+larger: 121 | self.prob[last_one] = 1 122 | 123 | def cuda(self): 124 | self.prob = self.prob.cuda() 125 | self.alias = self.alias.cuda() 126 | 127 | def draw(self, N): 128 | """ Draw N samples from multinomial """ 129 | K = self.alias.size(0) 130 | 131 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K) 132 | prob = self.prob.index_select(0, kk) 133 | alias = self.alias.index_select(0, kk) 134 | # b is whether a random number is greater than q 135 | b = torch.bernoulli(prob) 136 | oq = kk.mul(b.long()) 137 | oj = alias.mul((1-b).long()) 138 | 139 | return oq + oj -------------------------------------------------------------------------------- /dataset/__pycache__/cifar100.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/cifar100.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/cifar100.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/cifar100.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/imagenet_dali.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet_dali.cpython-36.pyc -------------------------------------------------------------------------------- /dataset/__pycache__/imagenet_dali.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet_dali.cpython-39.pyc -------------------------------------------------------------------------------- /dataset/base.py: 
-------------------------------------------------------------------------------- 1 | # https://github.com/tanglang96/DataLoaders_DALI/blob/master/base.py 2 | 3 | from nvidia.dali.plugin.pytorch import DALIGenericIterator 4 | 5 | class DALIDataloader(DALIGenericIterator): 6 | def __init__(self, pipeline, size, batch_size, output_map=["data", "label"], auto_reset=True, onehot_label=False): 7 | self.size = size 8 | self.batch_size = batch_size 9 | self.onehot_label = onehot_label 10 | self.output_map = output_map 11 | super().__init__(pipelines=pipeline, size=size, auto_reset=auto_reset, output_map=output_map) 12 | 13 | def __next__(self): 14 | if self._first_batch is not None: 15 | batch = self._first_batch 16 | self._first_batch = None 17 | return batch 18 | data = super().__next__()[0] 19 | if self.onehot_label: 20 | return [data[self.output_map[0]], data[self.output_map[1]].squeeze().long()] 21 | else: 22 | return [data[self.output_map[0]], data[self.output_map[1]]] 23 | 24 | def __len__(self): 25 | if self.size%self.batch_size==0: 26 | return self.size//self.batch_size 27 | else: 28 | return self.size//self.batch_size+1 29 | -------------------------------------------------------------------------------- /dataset/cifar100.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | from torchvision import datasets, transforms 7 | from PIL import Image 8 | 9 | """ 10 | mean = { 11 | 'cifar100': (0.5071, 0.4867, 0.4408), 12 | } 13 | 14 | std = { 15 | 'cifar100': (0.2675, 0.2565, 0.2761), 16 | } 17 | """ 18 | 19 | 20 | def get_data_folder(): 21 | """ 22 | return the path to store the data 23 | """ 24 | data_folder = '../data/' 25 | 26 | if not os.path.isdir(data_folder): 27 | os.makedirs(data_folder) 28 | 29 | return data_folder 30 | 31 | class CIFAR100BackCompat(datasets.CIFAR100): 32 | """ 33 | CIFAR100Instance+Sample Dataset 34 | """ 35 | 36 | @property 37 | def train_labels(self): 38 | return self.targets 39 | 40 | @property 41 | def test_labels(self): 42 | return self.targets 43 | 44 | @property 45 | def train_data(self): 46 | return self.data 47 | 48 | @property 49 | def test_data(self): 50 | return self.data 51 | 52 | class CIFAR100Instance(CIFAR100BackCompat): 53 | """CIFAR100Instance Dataset. 
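    Identical to `torchvision.datasets.CIFAR100` except that `__getitem__` also
    returns the sample index, which instance-wise contrastive objectives (e.g., CRD)
    use to address the memory buffer.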
54 | """ 55 | def __getitem__(self, index): 56 | 57 | img, target = self.data[index], self.targets[index] 58 | 59 | # doing this so that it is consistent with all other datasets 60 | # to return a PIL Image 61 | img = Image.fromarray(img) 62 | 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | 66 | if self.target_transform is not None: 67 | target = self.target_transform(target) 68 | 69 | return img, target, index 70 | 71 | 72 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False): 73 | """ 74 | cifar 100 75 | """ 76 | data_folder = get_data_folder() 77 | 78 | train_transform = transforms.Compose([ 79 | transforms.RandomCrop(32, padding=4), 80 | transforms.RandomHorizontalFlip(), 81 | transforms.ToTensor(), 82 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 83 | ]) 84 | test_transform = transforms.Compose([ 85 | transforms.ToTensor(), 86 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 87 | ]) 88 | 89 | if is_instance: 90 | train_set = CIFAR100Instance(root=data_folder, 91 | download=True, 92 | train=True, 93 | transform=train_transform) 94 | n_data = len(train_set) 95 | else: 96 | train_set = datasets.CIFAR100(root=data_folder, 97 | download=True, 98 | train=True, 99 | transform=train_transform) 100 | train_loader = DataLoader(train_set, 101 | batch_size=batch_size, 102 | shuffle=True, 103 | num_workers=num_workers) 104 | 105 | test_set = datasets.CIFAR100(root=data_folder, 106 | download=True, 107 | train=False, 108 | transform=test_transform) 109 | test_loader = DataLoader(test_set, 110 | batch_size=int(batch_size/2), 111 | shuffle=False, 112 | num_workers=int(num_workers/2)) 113 | 114 | if is_instance: 115 | return train_loader, test_loader, n_data 116 | else: 117 | return train_loader, test_loader 118 | 119 | 120 | class CIFAR100InstanceSample(CIFAR100BackCompat): 121 | """ 122 | CIFAR100Instance+Sample Dataset 123 | """ 124 | def __init__(self, root, train=True, 125 | transform=None, target_transform=None, 126 | download=False, k=4096, mode='exact', is_sample=True, percent=1.0): 127 | super().__init__(root=root, train=train, download=download, 128 | transform=transform, target_transform=target_transform) 129 | self.k = k 130 | self.mode = mode 131 | self.is_sample = is_sample 132 | 133 | num_classes = 100 134 | num_samples = len(self.data) 135 | label = self.targets 136 | 137 | self.cls_positive = [[] for i in range(num_classes)] 138 | for i in range(num_samples): 139 | self.cls_positive[label[i]].append(i) 140 | 141 | self.cls_negative = [[] for i in range(num_classes)] 142 | for i in range(num_classes): 143 | for j in range(num_classes): 144 | if j == i: 145 | continue 146 | self.cls_negative[i].extend(self.cls_positive[j]) 147 | 148 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)] 149 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)] 150 | 151 | if 0 < percent < 1: 152 | n = int(len(self.cls_negative[0]) * percent) 153 | self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n] 154 | for i in range(num_classes)] 155 | 156 | self.cls_positive = np.asarray(self.cls_positive) 157 | self.cls_negative = np.asarray(self.cls_negative) 158 | 159 | def __getitem__(self, index): 160 | 161 | img, target = self.data[index], self.targets[index] 162 | 163 | # doing this so that it is consistent with all other datasets 164 | # to return a PIL Image 165 | img = Image.fromarray(img) 166 | 167 | if self.transform is 
not None: 168 | img = self.transform(img) 169 | 170 | if self.target_transform is not None: 171 | target = self.target_transform(target) 172 | 173 | if not self.is_sample: 174 | # directly return 175 | return img, target, index 176 | else: 177 | # sample contrastive examples 178 | if self.mode == 'exact': 179 | pos_idx = index 180 | elif self.mode == 'relax': 181 | pos_idx = np.random.choice(self.cls_positive[target], 1) 182 | pos_idx = pos_idx[0] 183 | else: 184 | raise NotImplementedError(self.mode) 185 | replace = True if self.k > len(self.cls_negative[target]) else False 186 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace) 187 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 188 | return img, target, index, sample_idx 189 | 190 | def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact', 191 | is_sample=True, percent=1.0): 192 | """ 193 | cifar 100 194 | """ 195 | data_folder = get_data_folder() 196 | 197 | train_transform = transforms.Compose([ 198 | transforms.RandomCrop(32, padding=4), 199 | transforms.RandomHorizontalFlip(), 200 | transforms.ToTensor(), 201 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 202 | ]) 203 | test_transform = transforms.Compose([ 204 | transforms.ToTensor(), 205 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), 206 | ]) 207 | 208 | train_set = CIFAR100InstanceSample(root=data_folder, 209 | download=True, 210 | train=True, 211 | transform=train_transform, 212 | k=k, 213 | mode=mode, 214 | is_sample=is_sample, 215 | percent=percent) 216 | n_data = len(train_set) 217 | train_loader = DataLoader(train_set, 218 | batch_size=batch_size, 219 | shuffle=True, 220 | num_workers=num_workers) 221 | 222 | test_set = datasets.CIFAR100(root=data_folder, 223 | download=True, 224 | train=False, 225 | transform=test_transform) 226 | test_loader = DataLoader(test_set, 227 | batch_size=int(batch_size/2), 228 | shuffle=False, 229 | num_workers=int(num_workers/2)) 230 | 231 | return train_loader, test_loader, n_data 232 | -------------------------------------------------------------------------------- /dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | """ 2 | get data loaders 3 | """ 4 | from __future__ import print_function 5 | 6 | import os 7 | import numpy as np 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torchvision import datasets 11 | from torchvision import transforms 12 | 13 | def get_data_folder(dataset='imagenet'): 14 | """ 15 | return the path to store the data 16 | """ 17 | data_folder = os.path.join('/home/cdf/RepDistiller-master/RepDistiller-master/data', dataset) 18 | 19 | if not os.path.isdir(data_folder): 20 | os.makedirs(data_folder) 21 | 22 | return data_folder 23 | 24 | 25 | class ImageFolderInstance(datasets.ImageFolder): 26 | """: Folder datasets which returns the index of the image as well:: 27 | """ 28 | def __getitem__(self, index): 29 | """ 30 | Args: 31 | index (int): Index 32 | Returns: 33 | tuple: (image, target) where target is class_index of the target class. 
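            Note: unlike the base `ImageFolder`, this subclass actually returns a
            3-tuple (image, target, index).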
34 | """ 35 | path, target = self.imgs[index] 36 | img = self.loader(path) 37 | if self.transform is not None: 38 | img = self.transform(img) 39 | if self.target_transform is not None: 40 | target = self.target_transform(target) 41 | 42 | return img, target, index 43 | 44 | 45 | class ImageFolderSample(datasets.ImageFolder): 46 | """: Folder datasets which returns (img, label, index, contrast_index): 47 | """ 48 | def __init__(self, root, transform=None, target_transform=None, 49 | is_sample=False, k=4096): 50 | super().__init__(root=root, transform=transform, target_transform=target_transform) 51 | 52 | self.k = k 53 | self.is_sample = is_sample 54 | 55 | print('stage1 finished!') 56 | 57 | if self.is_sample: 58 | num_classes = len(self.classes) 59 | num_samples = len(self.samples) 60 | label = np.zeros(num_samples, dtype=np.int32) 61 | for i in range(num_samples): 62 | path, target = self.imgs[i] 63 | label[i] = target 64 | 65 | self.cls_positive = [[] for i in range(num_classes)] 66 | for i in range(num_samples): 67 | self.cls_positive[label[i]].append(i) 68 | 69 | self.cls_negative = [[] for i in range(num_classes)] 70 | for i in range(num_classes): 71 | for j in range(num_classes): 72 | if j == i: 73 | continue 74 | self.cls_negative[i].extend(self.cls_positive[j]) 75 | 76 | self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)] 77 | self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)] 78 | 79 | print('dataset initialized!') 80 | 81 | def __getitem__(self, index): 82 | """ 83 | Args: 84 | index (int): Index 85 | Returns: 86 | tuple: (image, target) where target is class_index of the target class. 87 | """ 88 | path, target = self.imgs[index] 89 | img = self.loader(path) 90 | if self.transform is not None: 91 | img = self.transform(img) 92 | if self.target_transform is not None: 93 | target = self.target_transform(target) 94 | 95 | if self.is_sample: 96 | # sample contrastive examples 97 | pos_idx = index 98 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True) 99 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx)) 100 | # sample_idx = np.hstack((neg_idx)) 101 | return img, target, index, sample_idx 102 | else: 103 | return img, target, index 104 | 105 | 106 | def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8): 107 | """get the test data loader""" 108 | 109 | if dataset == 'imagenet': 110 | data_folder = get_data_folder(dataset) 111 | else: 112 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 113 | 114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 115 | std=[0.229, 0.224, 0.225]) 116 | test_transform = transforms.Compose([ 117 | transforms.Resize(256), 118 | transforms.CenterCrop(224), 119 | transforms.ToTensor(), 120 | normalize, 121 | ]) 122 | 123 | test_folder = os.path.join(data_folder, 'val') 124 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 125 | test_loader = DataLoader(test_set, 126 | batch_size=batch_size, 127 | shuffle=False, 128 | num_workers=num_workers, 129 | pin_memory=True) 130 | 131 | return test_loader 132 | 133 | 134 | def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8, 135 | is_sample=False, k=4096, multiprocessing_distributed=False): 136 | """Data Loader for ImageNet""" 137 | 138 | if dataset == 'imagenet': 139 | data_folder = get_data_folder(dataset) 140 | else: 141 | raise NotImplementedError('dataset not supported: 
{}'.format(dataset)) 142 | 143 | # add data transform 144 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 145 | std=[0.229, 0.224, 0.225]) 146 | train_transform = transforms.Compose([ 147 | transforms.RandomResizedCrop(224), 148 | transforms.RandomHorizontalFlip(), 149 | transforms.ToTensor(), 150 | normalize, 151 | ]) 152 | test_transform = transforms.Compose([ 153 | transforms.Resize(256), 154 | transforms.CenterCrop(224), 155 | transforms.ToTensor(), 156 | normalize, 157 | ]) 158 | train_folder = os.path.join(data_folder, 'train') 159 | test_folder = os.path.join(data_folder, 'val') 160 | 161 | train_set = ImageFolderSample(train_folder, transform=train_transform, is_sample=is_sample, k=k) 162 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 163 | 164 | if multiprocessing_distributed: 165 | train_sampler = DistributedSampler(train_set) 166 | test_sampler = DistributedSampler(test_set, shuffle=False) 167 | else: 168 | train_sampler = None 169 | test_sampler = None 170 | 171 | train_loader = DataLoader(train_set, 172 | batch_size=batch_size, 173 | shuffle=(train_sampler is None), 174 | num_workers=num_workers, 175 | pin_memory=True, 176 | sampler=train_sampler) 177 | test_loader = DataLoader(test_set, 178 | batch_size=batch_size, 179 | shuffle=False, 180 | num_workers=num_workers, 181 | pin_memory=True, 182 | sampler=test_sampler) 183 | 184 | print('num_samples', len(train_set.samples)) 185 | print('num_class', len(train_set.classes)) 186 | 187 | return train_loader, test_loader, len(train_set), len(train_set.classes), train_sampler 188 | 189 | 190 | def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16, 191 | multiprocessing_distributed=False): 192 | """ 193 | Data Loader for imagenet 194 | """ 195 | if dataset == 'imagenet': 196 | data_folder = get_data_folder(dataset) 197 | else: 198 | raise NotImplementedError('dataset not supported: {}'.format(dataset)) 199 | 200 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 201 | std=[0.229, 0.224, 0.225]) 202 | train_transform = transforms.Compose([ 203 | transforms.RandomResizedCrop(224), 204 | transforms.RandomHorizontalFlip(), 205 | transforms.ToTensor(), 206 | normalize, 207 | ]) 208 | test_transform = transforms.Compose([ 209 | transforms.Resize(256), 210 | transforms.CenterCrop(224), 211 | transforms.ToTensor(), 212 | normalize, 213 | ]) 214 | 215 | train_folder = os.path.join(data_folder, 'train') 216 | test_folder = os.path.join(data_folder, 'val') 217 | 218 | train_set = datasets.ImageFolder(train_folder, transform=train_transform) 219 | test_set = datasets.ImageFolder(test_folder, transform=test_transform) 220 | 221 | if multiprocessing_distributed: 222 | train_sampler = DistributedSampler(train_set) 223 | test_sampler = DistributedSampler(test_set, shuffle=False) 224 | else: 225 | train_sampler = None 226 | test_sampler = None 227 | 228 | train_loader = DataLoader(train_set, 229 | batch_size=batch_size, 230 | shuffle=(train_sampler is None), 231 | num_workers=num_workers, 232 | pin_memory=True, 233 | sampler=train_sampler) 234 | 235 | test_loader = DataLoader(test_set, 236 | batch_size=batch_size, 237 | shuffle=False, 238 | num_workers=num_workers, 239 | pin_memory=True, 240 | sampler=test_sampler) 241 | 242 | return train_loader, test_loader, train_sampler 243 | -------------------------------------------------------------------------------- /dataset/imagenet_dali.py: -------------------------------------------------------------------------------- 1 | # 
https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.parallel 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.optim 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | 21 | import numpy as np 22 | 23 | try: 24 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy 25 | from nvidia.dali.pipeline import pipeline_def 26 | import nvidia.dali.types as types 27 | import nvidia.dali.fn as fn 28 | except ImportError: 29 | raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.") 30 | 31 | @pipeline_def 32 | def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True): 33 | images, labels = fn.readers.file(file_root=data_dir, 34 | shard_id=shard_id, 35 | num_shards=num_shards, 36 | random_shuffle=is_training, 37 | pad_last_batch=True, 38 | name="Reader") 39 | dali_device = 'cpu' if dali_cpu else 'gpu' 40 | decoder_device = 'cpu' if dali_cpu else 'mixed' 41 | device_memory_padding = 211025920 if decoder_device == 'mixed' else 0 42 | host_memory_padding = 140544512 if decoder_device == 'mixed' else 0 43 | if is_training: 44 | images = fn.decoders.image_random_crop(images, 45 | device=decoder_device, output_type=types.RGB, 46 | device_memory_padding=device_memory_padding, 47 | host_memory_padding=host_memory_padding, 48 | random_aspect_ratio=[0.8, 1.25], 49 | random_area=[0.1, 1.0], 50 | num_attempts=100) 51 | images = fn.resize(images, 52 | device=dali_device, 53 | resize_x=crop, 54 | resize_y=crop, 55 | interp_type=types.INTERP_TRIANGULAR) 56 | mirror = fn.random.coin_flip(probability=0.5) 57 | else: 58 | images = fn.decoders.image(images, 59 | device=decoder_device, 60 | output_type=types.RGB) 61 | images = fn.resize(images, 62 | device=dali_device, 63 | size=size, 64 | mode="not_smaller", 65 | interp_type=types.INTERP_TRIANGULAR) 66 | mirror = False 67 | 68 | images = fn.crop_mirror_normalize(images.gpu(), 69 | dtype=types.FLOAT, 70 | output_layout="CHW", 71 | crop=(crop, crop), 72 | mean=[0.485 * 255,0.456 * 255,0.406 * 255], 73 | std=[0.229 * 255,0.224 * 255,0.225 * 255], 74 | mirror=mirror) 75 | labels = labels.gpu() 76 | return images, labels 77 | 78 | def get_dali_data_loader(args): 79 | crop_size = 224 80 | val_size = 256 81 | 82 | path = '../data' 83 | data_folder = os.path.join(path, args.dataset) 84 | if not os.path.isdir(data_folder): 85 | print('Please place the ImageNet dataset at: ', path) 86 | 87 | traindir = os.path.join(data_folder, 'train') 88 | valdir = os.path.join(data_folder, 'val') 89 | 90 | pipe = create_dali_pipeline(batch_size=args.batch_size, 91 | num_threads=args.num_workers, 92 | device_id=args.rank, 93 | seed=12 + args.rank, 94 | data_dir=traindir, 95 | crop=crop_size, 96 | size=val_size, 97 | dali_cpu=args.dali == 'cpu', 98 | shard_id=args.rank, 99 | num_shards=args.world_size, 100 | is_training=True) 101 | pipe.build() 102 | train_loader = DALIClassificationIterator(pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL) 103 | 104 | pipe = create_dali_pipeline(batch_size=args.batch_size, 105 | num_threads=args.num_workers, 106 | 
device_id=args.rank, 107 | seed=12 + args.rank, 108 | data_dir=valdir, 109 | crop=crop_size, 110 | size=val_size, 111 | dali_cpu=args.dali == 'cpu', 112 | shard_id=args.rank, 113 | num_shards=args.world_size, 114 | is_training=False) 115 | pipe.build() 116 | val_loader = DALIClassificationIterator(pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL) 117 | 118 | return train_loader, val_loader -------------------------------------------------------------------------------- /distiller_zoo/AT.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Attention(nn.Module): 8 | """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks 9 | via Attention Transfer 10 | code: https://github.com/szagoruyko/attention-transfer""" 11 | def __init__(self, p=2): 12 | super(Attention, self).__init__() 13 | self.p = p 14 | 15 | def forward(self, g_s, g_t): 16 | # only calculate min(len(g_s), len(g_t))-pair at_loss with the help of zip function 17 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)] 18 | 19 | def at_loss(self, f_s, f_t): 20 | s_H, t_H = f_s.shape[2], f_t.shape[2] 21 | if s_H > t_H: 22 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H)) 23 | elif s_H < t_H: 24 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H)) 25 | else: 26 | pass 27 | return (self.at(f_s) - self.at(f_t)).pow(2).mean() 28 | 29 | def at(self, f): 30 | # mean(1) function reduce feature map BxCxHxW into BxHxW by averaging the channel response 31 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1)) 32 | -------------------------------------------------------------------------------- /distiller_zoo/FitNet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class HintLoss(nn.Module): 7 | """Fitnets: hints for thin deep nets, ICLR 2015""" 8 | def __init__(self): 9 | super(HintLoss, self).__init__() 10 | self.crit = nn.MSELoss() 11 | 12 | def forward(self, f_s, f_t): 13 | loss = self.crit(f_s, f_t) 14 | return loss 15 | -------------------------------------------------------------------------------- /distiller_zoo/KD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class DistillKL(nn.Module): 8 | """Distilling the Knowledge in a Neural Network""" 9 | def __init__(self, T): 10 | super(DistillKL, self).__init__() 11 | self.T = T 12 | 13 | def forward(self, y_s, y_t): 14 | p_s = F.log_softmax(y_s/self.T, dim=1) 15 | p_t = F.softmax(y_t/self.T, dim=1) 16 | loss = nn.KLDivLoss(reduction='batchmean')(p_s, p_t) * (self.T**2) 17 | return loss 18 | -------------------------------------------------------------------------------- /distiller_zoo/SP.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class Similarity(nn.Module): 9 | """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author""" 10 | def __init__(self): 11 | super(Similarity, self).__init__() 12 | 13 | def forward(self, g_s, g_t): 14 | return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, 
g_t)] 15 | 16 | def similarity_loss(self, f_s, f_t): 17 | bsz = f_s.shape[0] 18 | f_s = f_s.view(bsz, -1) 19 | f_t = f_t.view(bsz, -1) 20 | 21 | G_s = torch.mm(f_s, torch.t(f_s)) 22 | # G_s = G_s / G_s.norm(2) 23 | G_s = torch.nn.functional.normalize(G_s, dim=1) 24 | G_t = torch.mm(f_t, torch.t(f_t)) 25 | # G_t = G_t / G_t.norm(2) 26 | G_t = torch.nn.functional.normalize(G_t, dim=1) 27 | 28 | G_diff = G_t - G_s 29 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz) 30 | return loss 31 | -------------------------------------------------------------------------------- /distiller_zoo/SemCKD.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SemCKDLoss(nn.Module): 9 | """Cross-Layer Distillation with Semantic Calibration, AAAI2021""" 10 | def __init__(self): 11 | super(SemCKDLoss, self).__init__() 12 | self.crit = nn.MSELoss(reduction='none') 13 | 14 | def forward(self, s_value, f_target, weight): 15 | bsz, num_stu, num_tea = weight.shape 16 | ind_loss = torch.zeros(bsz, num_stu, num_tea).cuda() 17 | 18 | for i in range(num_stu): 19 | for j in range(num_tea): 20 | ind_loss[:, i, j] = self.crit(s_value[i][j], f_target[i][j]).reshape(bsz,-1).mean(-1) 21 | 22 | loss = (weight * ind_loss).sum()/(1.0*bsz*num_stu) 23 | return loss -------------------------------------------------------------------------------- /distiller_zoo/VID.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | class VIDLoss(nn.Module): 10 | """Variational Information Distillation for Knowledge Transfer (CVPR 2019), 11 | code from author: https://github.com/ssahn0215/variational-information-distillation""" 12 | def __init__(self, 13 | num_input_channels, 14 | num_mid_channel, 15 | num_target_channels, 16 | init_pred_var=5.0, 17 | eps=1e-5): 18 | super(VIDLoss, self).__init__() 19 | 20 | def conv1x1(in_channels, out_channels, stride=1): 21 | return nn.Conv2d( 22 | in_channels, out_channels, 23 | kernel_size=1, padding=0, 24 | bias=False, stride=stride) 25 | 26 | self.regressor = nn.Sequential( 27 | conv1x1(num_input_channels, num_mid_channel), 28 | nn.ReLU(), 29 | conv1x1(num_mid_channel, num_mid_channel), 30 | nn.ReLU(), 31 | conv1x1(num_mid_channel, num_target_channels), 32 | ) 33 | self.log_scale = torch.nn.Parameter( 34 | np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels) 35 | ) 36 | self.eps = eps 37 | 38 | def forward(self, input, target): 39 | # pool for dimentsion match 40 | s_H, t_H = input.shape[2], target.shape[2] 41 | if s_H > t_H: 42 | input = F.adaptive_avg_pool2d(input, (t_H, t_H)) 43 | elif s_H < t_H: 44 | target = F.adaptive_avg_pool2d(target, (s_H, s_H)) 45 | else: 46 | pass 47 | pred_mean = self.regressor(input) 48 | pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps 49 | pred_var = pred_var.view(1, -1, 1, 1) 50 | neg_log_prob = 0.5*( 51 | (pred_mean-target)**2/pred_var+torch.log(pred_var) 52 | ) 53 | loss = torch.mean(neg_log_prob) 54 | return loss 55 | -------------------------------------------------------------------------------- /distiller_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .FitNet import HintLoss 2 | from .AT import Attention 3 | from 
.KD import DistillKL 4 | from .SP import Similarity 5 | from .VID import VIDLoss 6 | from .SemCKD import SemCKDLoss 7 | -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/AT.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/AT.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/AT.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/AT.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/FitNet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/FitNet.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/FitNet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/FitNet.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/KD.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/KD.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/KD.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/KD.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/SP.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SP.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/SP.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SP.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/SemCKD.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SemCKD.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/SemCKD.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SemCKD.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/VID.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/VID.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/VID.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/VID.cpython-39.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /distiller_zoo/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /helper/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__init__.py -------------------------------------------------------------------------------- /helper/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /helper/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /helper/__pycache__/loops.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/loops.cpython-36.pyc -------------------------------------------------------------------------------- /helper/__pycache__/loops.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/loops.cpython-39.pyc -------------------------------------------------------------------------------- /helper/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/util.cpython-36.pyc 
-------------------------------------------------------------------------------- /helper/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /helper/loops.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from cProfile import label 3 | 4 | import sys 5 | import time 6 | import torch 7 | from .util import AverageMeter, accuracy, reduce_tensor 8 | 9 | def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt): 10 | """vanilla training""" 11 | model.train() 12 | 13 | batch_time = AverageMeter() 14 | losses = AverageMeter() 15 | top1 = AverageMeter() 16 | top5 = AverageMeter() 17 | 18 | n_batch = len(train_loader) if opt.dali is None else (train_loader._size + opt.batch_size - 1) // opt.batch_size 19 | 20 | end = time.time() 21 | for idx, batch_data in enumerate(train_loader): 22 | if opt.dali is None: 23 | images, labels = batch_data 24 | else: 25 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long() 26 | 27 | if opt.gpu is not None: 28 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 29 | if torch.cuda.is_available(): 30 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 31 | 32 | # ===================forward===================== 33 | output = model(images) 34 | loss = criterion(output, labels) 35 | losses.update(loss.item(), images.size(0)) 36 | 37 | # ===================Metrics===================== 38 | metrics = accuracy(output, labels, topk=(1, 5)) 39 | top1.update(metrics[0].item(), images.size(0)) 40 | top5.update(metrics[1].item(), images.size(0)) 41 | batch_time.update(time.time() - end) 42 | 43 | # ===================backward===================== 44 | optimizer.zero_grad() 45 | loss.backward() 46 | 47 | optimizer.step() 48 | 49 | # print info 50 | if idx % opt.print_freq == 0: 51 | print('Epoch: [{0}][{1}/{2}]\t' 52 | 'GPU {3}\t' 53 | 'Time: {batch_time.avg:.3f}\t' 54 | 'Loss {loss.avg:.4f}\t' 55 | 'Acc@1 {top1.avg:.3f}\t' 56 | 'Acc@5 {top5.avg:.3f}'.format( 57 | epoch, idx, n_batch, opt.gpu, batch_time=batch_time, 58 | loss=losses, top1=top1, top5=top5)) 59 | sys.stdout.flush() 60 | 61 | return top1.avg, top5.avg, losses.avg 62 | 63 | def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt): 64 | """one epoch distillation""" 65 | # set modules as train() 66 | for module in module_list: 67 | module.train() 68 | # set teacher as eval() 69 | module_list[-1].eval() 70 | 71 | criterion_cls = criterion_list[0] 72 | criterion_div = criterion_list[1] 73 | criterion_kd = criterion_list[2] 74 | 75 | model_s = module_list[0] 76 | model_t = module_list[-1] 77 | 78 | batch_time = AverageMeter() 79 | losses = AverageMeter() 80 | top1 = AverageMeter() 81 | top5 = AverageMeter() 82 | 83 | n_batch = len(train_loader) if opt.dali is None else (train_loader._size + opt.batch_size - 1) // opt.batch_size 84 | 85 | end = time.time() 86 | for idx, data in enumerate(train_loader): 87 | if opt.dali is None: 88 | if opt.distill in ['crd']: 89 | images, labels, index, contrast_idx = data 90 | else: 91 | images, labels = data 92 | else: 93 | images, labels = data[0]['data'], 
data[0]['label'].squeeze().long() 94 | 95 | if opt.distill == 'semckd' and images.shape[0] < opt.batch_size: 96 | continue 97 | 98 | if opt.gpu is not None: 99 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 100 | if torch.cuda.is_available(): 101 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 102 | if opt.distill in ['crd']: 103 | index = index.cuda() 104 | contrast_idx = contrast_idx.cuda() 105 | 106 | # ===================forward===================== 107 | feat_s, logit_s = model_s(images, is_feat=True) 108 | with torch.no_grad(): 109 | feat_t, logit_t = model_t(images, is_feat=True) 110 | feat_t = [f.detach() for f in feat_t] 111 | 112 | cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1] 113 | 114 | # cls + kl div 115 | loss_cls = criterion_cls(logit_s, labels) 116 | loss_div = criterion_div(logit_s, logit_t) 117 | 118 | # other kd loss 119 | if opt.distill == 'kd': 120 | loss_kd = 0 121 | elif opt.distill == 'hint': 122 | f_s, f_t = module_list[1](feat_s[opt.hint_layer], feat_t[opt.hint_layer]) 123 | loss_kd = criterion_kd(f_s, f_t) 124 | elif opt.distill == 'attention': 125 | # include 1, exclude -1. 126 | g_s = feat_s[1:-1] 127 | g_t = feat_t[1:-1] 128 | loss_group = criterion_kd(g_s, g_t) 129 | loss_kd = sum(loss_group) 130 | elif opt.distill == 'similarity': 131 | g_s = [feat_s[-2]] 132 | g_t = [feat_t[-2]] 133 | loss_group = criterion_kd(g_s, g_t) 134 | loss_kd = sum(loss_group) 135 | elif opt.distill == 'vid': 136 | g_s = feat_s[1:-1] 137 | g_t = feat_t[1:-1] 138 | loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)] 139 | loss_kd = sum(loss_group) 140 | elif opt.distill == 'crd': 141 | f_s = feat_s[-1] 142 | f_t = feat_t[-1] 143 | loss_kd = criterion_kd(f_s, f_t, index, contrast_idx) 144 | elif opt.distill == 'semckd': 145 | s_value, f_target, weight = module_list[1](feat_s[1:-1], feat_t[1:-1]) 146 | loss_kd = criterion_kd(s_value, f_target, weight) 147 | elif opt.distill == 'srrl': 148 | trans_feat_s, pred_feat_s = module_list[1](feat_s[-1], cls_t) 149 | loss_kd = criterion_kd(trans_feat_s, feat_t[-1]) + criterion_kd(pred_feat_s, logit_t) 150 | elif opt.distill == 'simkd': 151 | trans_feat_s, trans_feat_t, pred_feat_s = module_list[1](feat_s[-2], feat_t[-2], cls_t) 152 | logit_s = pred_feat_s 153 | loss_kd = criterion_kd(trans_feat_s, trans_feat_t) 154 | else: 155 | raise NotImplementedError(opt.distill) 156 | 157 | loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd 158 | losses.update(loss.item(), images.size(0)) 159 | 160 | # ===================Metrics===================== 161 | metrics = accuracy(logit_s, labels, topk=(1, 5)) 162 | top1.update(metrics[0].item(), images.size(0)) 163 | top5.update(metrics[1].item(), images.size(0)) 164 | batch_time.update(time.time() - end) 165 | 166 | # ===================backward===================== 167 | optimizer.zero_grad() 168 | loss.backward() 169 | optimizer.step() 170 | 171 | # print info 172 | if idx % opt.print_freq == 0: 173 | print('Epoch: [{0}][{1}/{2}]\t' 174 | 'GPU {3}\t' 175 | 'Time: {batch_time.avg:.3f}\t' 176 | 'Loss {loss.avg:.4f}\t' 177 | 'Acc@1 {top1.avg:.3f}\t' 178 | 'Acc@5 {top5.avg:.3f}'.format( 179 | epoch, idx, n_batch, opt.gpu, loss=losses, top1=top1, top5=top5, 180 | batch_time=batch_time)) 181 | sys.stdout.flush() 182 | 183 | return top1.avg, top5.avg, losses.avg 184 | 185 | def validate_vanilla(val_loader, model, 
criterion, opt): 186 | """validation""" 187 | 188 | batch_time = AverageMeter() 189 | losses = AverageMeter() 190 | top1 = AverageMeter() 191 | top5 = AverageMeter() 192 | 193 | # switch to evaluate mode 194 | model.eval() 195 | 196 | n_batch = len(val_loader) if opt.dali is None else (val_loader._size + opt.batch_size - 1) // opt.batch_size 197 | 198 | with torch.no_grad(): 199 | end = time.time() 200 | for idx, batch_data in enumerate(val_loader): 201 | 202 | if opt.dali is None: 203 | images, labels = batch_data 204 | else: 205 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long() 206 | 207 | if opt.gpu is not None: 208 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 209 | if torch.cuda.is_available(): 210 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 211 | 212 | # compute output 213 | output = model(images) 214 | loss = criterion(output, labels) 215 | losses.update(loss.item(), images.size(0)) 216 | 217 | # ===================Metrics===================== 218 | metrics = accuracy(output, labels, topk=(1, 5)) 219 | top1.update(metrics[0].item(), images.size(0)) 220 | top5.update(metrics[1].item(), images.size(0)) 221 | batch_time.update(time.time() - end) 222 | 223 | if idx % opt.print_freq == 0: 224 | print('Test: [{0}/{1}]\t' 225 | 'GPU: {2}\t' 226 | 'Time: {batch_time.avg:.3f}\t' 227 | 'Loss {loss.avg:.4f}\t' 228 | 'Acc@1 {top1.avg:.3f}\t' 229 | 'Acc@5 {top5.avg:.3f}'.format( 230 | idx, n_batch, opt.gpu, batch_time=batch_time, loss=losses, 231 | top1=top1, top5=top5)) 232 | 233 | if opt.multiprocessing_distributed: 234 | # Batch size may not be equal across multiple gpus 235 | total_metrics = torch.tensor([top1.sum, top5.sum, losses.sum]).to(opt.gpu) 236 | count_metrics = torch.tensor([top1.count, top5.count, losses.count]).to(opt.gpu) 237 | total_metrics = reduce_tensor(total_metrics, 1) # here world_size=1, because they should be summed up 238 | count_metrics = reduce_tensor(count_metrics, 1) 239 | ret = [] 240 | for s, n in zip(total_metrics.tolist(), count_metrics.tolist()): 241 | ret.append(s / (1.0 * n)) 242 | return ret 243 | 244 | return top1.avg, top5.avg, losses.avg 245 | 246 | 247 | def validate_distill(val_loader, module_list, criterion, opt): 248 | """validation""" 249 | 250 | batch_time = AverageMeter() 251 | losses = AverageMeter() 252 | top1 = AverageMeter() 253 | top5 = AverageMeter() 254 | 255 | # switch to evaluate mode 256 | for module in module_list: 257 | module.eval() 258 | 259 | model_s = module_list[0] 260 | model_t = module_list[-1] 261 | n_batch = len(val_loader) if opt.dali is None else (val_loader._size + opt.batch_size - 1) // opt.batch_size 262 | 263 | with torch.no_grad(): 264 | end = time.time() 265 | for idx, batch_data in enumerate(val_loader): 266 | 267 | if opt.dali is None: 268 | images, labels = batch_data 269 | else: 270 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long() 271 | 272 | if opt.gpu is not None: 273 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 274 | if torch.cuda.is_available(): 275 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True) 276 | 277 | # compute output 278 | if opt.distill == 'simkd': 279 | feat_s, _ = model_s(images, is_feat=True) 280 | feat_t, _ = model_t(images, is_feat=True) 281 | feat_t = [f.detach() for f in feat_t] 282 | cls_t = model_t.module.get_feat_modules()[-1] if 
opt.multiprocessing_distributed else model_t.get_feat_modules()[-1] 283 | _, _, output = module_list[1](feat_s[-2], feat_t[-2], cls_t) 284 | else: 285 | output = model_s(images) 286 | loss = criterion(output, labels) 287 | losses.update(loss.item(), images.size(0)) 288 | 289 | # ===================Metrics===================== 290 | metrics = accuracy(output, labels, topk=(1, 5)) 291 | top1.update(metrics[0].item(), images.size(0)) 292 | top5.update(metrics[1].item(), images.size(0)) 293 | batch_time.update(time.time() - end) 294 | 295 | if idx % opt.print_freq == 0: 296 | print('Test: [{0}/{1}]\t' 297 | 'GPU: {2}\t' 298 | 'Time: {batch_time.avg:.3f}\t' 299 | 'Loss {loss.avg:.4f}\t' 300 | 'Acc@1 {top1.avg:.3f}\t' 301 | 'Acc@5 {top5.avg:.3f}'.format( 302 | idx, n_batch, opt.gpu, batch_time=batch_time, loss=losses, 303 | top1=top1, top5=top5)) 304 | 305 | if opt.multiprocessing_distributed: 306 | # Batch size may not be equal across multiple gpus 307 | total_metrics = torch.tensor([top1.sum, top5.sum, losses.sum]).to(opt.gpu) 308 | count_metrics = torch.tensor([top1.count, top5.count, losses.count]).to(opt.gpu) 309 | total_metrics = reduce_tensor(total_metrics, 1) # here world_size=1, because they should be summed up 310 | count_metrics = reduce_tensor(count_metrics, 1) 311 | ret = [] 312 | for s, n in zip(total_metrics.tolist(), count_metrics.tolist()): 313 | ret.append(s / (1.0 * n)) 314 | return ret 315 | 316 | return top1.avg, top5.avg, losses.avg 317 | -------------------------------------------------------------------------------- /helper/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import json 4 | import torch 5 | import numpy as np 6 | import torch.distributed as dist 7 | 8 | def adjust_learning_rate(epoch, opt, optimizer): 9 | """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 10 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs)) 11 | if steps > 0: 12 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps) 13 | for param_group in optimizer.param_groups: 14 | param_group['lr'] = new_lr 15 | 16 | # def adjust_learning_rate(optimizer, epoch, step, len_epoch, old_lr): 17 | # """Sets the learning rate to the initial LR decayed by decay rate every steep step""" 18 | # if epoch < 5: 19 | # lr = old_lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch) 20 | # elif 5 <= epoch < 60: return 21 | # else: 22 | # factor = epoch // 30 23 | # factor -= 1 24 | # lr = old_lr*(0.1**factor) 25 | 26 | # for param_group in optimizer.param_groups: 27 | # param_group['lr'] = lr 28 | 29 | 30 | class AverageMeter(object): 31 | """Computes and stores the average and current value""" 32 | def __init__(self): 33 | self.reset() 34 | 35 | def reset(self): 36 | self.val = 0 37 | self.avg = 0 38 | self.sum = 0 39 | self.count = 0 40 | 41 | def update(self, val, n=1): 42 | self.val = val 43 | self.sum += val * n 44 | self.count += n 45 | self.avg = self.sum / self.count 46 | 47 | def accuracy(output, target, topk=(1,)): 48 | """Computes the accuracy over the k top predictions for the specified values of k""" 49 | with torch.no_grad(): 50 | maxk = max(topk) 51 | batch_size = target.size(0) 52 | 53 | _, pred = output.topk(maxk, dim = 1, largest = True, sorted = True) 54 | pred = pred.t() 55 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 56 | 57 | res = [] 58 | for k in topk: 59 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 60 | 
res.append(correct_k.mul_(100.0 / batch_size)) 61 | return res 62 | 63 | def save_dict_to_json(d, json_path): 64 | """Saves dict of floats in json file 65 | 66 | Args: 67 | d: (dict) of float-castable values (np.float, int, float, etc.) 68 | json_path: (string) path to json file 69 | """ 70 | with open(json_path, 'w') as f: 71 | # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) 72 | d = {k: v for k, v in d.items()} 73 | json.dump(d, f, indent=4) 74 | 75 | def load_json_to_dict(json_path): 76 | """Loads json file to dict 77 | 78 | Args: 79 | json_path: (string) path to json file 80 | """ 81 | with open(json_path, 'r') as f: 82 | params = json.load(f) 83 | return params 84 | 85 | def reduce_tensor(tensor, world_size = 1, op='avg'): 86 | rt = tensor.clone() 87 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 88 | if world_size > 1: 89 | rt = torch.true_divide(rt, world_size) 90 | return rt 91 | 92 | if __name__ == '__main__': 93 | 94 | pass 95 | -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/.DS_Store -------------------------------------------------------------------------------- /images/SimKD_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/SimKD_result.png -------------------------------------------------------------------------------- /images/cifar100_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/cifar100_result.png -------------------------------------------------------------------------------- /models/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N,C,H,W = x.size() 17 | g = self.groups 18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 23 | super(Bottleneck, self).__init__() 24 | self.is_last = is_last 25 | self.stride = stride 26 | 27 | mid_planes = int(out_planes/4) 28 | g = 1 if in_planes == 24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res 48 | out = F.relu(preact) 49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 50 | if self.is_last: 51 | return out, preact 52 | else: 53 | return out 54 | 55 | 56 | class ShuffleNet(nn.Module): 57 | def __init__(self, cfg, num_classes=10): 58 | super(ShuffleNet, self).__init__() 59 | out_planes = cfg['out_planes'] 60 | num_blocks = cfg['num_blocks'] 61 | groups = cfg['groups'] 62 | 63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(24) 65 | self.in_planes = 24 66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 69 | self.linear = nn.Linear(out_planes[2], num_classes) 70 | 71 | def _make_layer(self, out_planes, num_blocks, groups): 72 | layers = [] 73 | for i in range(num_blocks): 74 | stride = 2 if i == 0 else 1 75 | cat_planes = self.in_planes if i == 0 else 0 76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 77 | stride=stride, 78 | groups=groups, 79 | is_last=(i == num_blocks - 1))) 80 | self.in_planes = out_planes 81 | return nn.Sequential(*layers) 82 | 83 | def get_feat_modules(self): 84 | feat_m = nn.ModuleList([]) 85 | feat_m.append(self.conv1) 86 | feat_m.append(self.bn1) 87 | feat_m.append(self.layer1) 88 | feat_m.append(self.layer2) 89 | feat_m.append(self.layer3) 90 | return feat_m 91 | 92 | def get_bn_before_relu(self): 93 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 94 | 95 | def forward(self, x, is_feat=False, preact=False): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | f0 = out 98 | out, f1_pre = self.layer1(out) 99 | f1 = out 100 | out, f2_pre = self.layer2(out) 101 | f2 = out 102 | out, f3_pre = self.layer3(out) 103 | f3 = out 104 
| out = F.avg_pool2d(out, 4) 105 | out = out.view(out.size(0), -1) 106 | f4 = out 107 | out = self.linear(out) 108 | 109 | if is_feat: 110 | if preact: 111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 112 | else: 113 | return [f0, f1, f2, f3, f4], out 114 | else: 115 | return out 116 | 117 | 118 | def ShuffleV1(**kwargs): 119 | cfg = { 120 | 'out_planes': [240, 480, 960], 121 | 'num_blocks': [4, 8, 4], 122 | 'groups': 3 123 | } 124 | return ShuffleNet(cfg, **kwargs) 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | x = torch.randn(2, 3, 32, 32) 130 | net = ShuffleV1(num_classes=100) 131 | import time 132 | a = time.time() 133 | feats, logit = net(x, is_feat=True, preact=True) 134 | b = time.time() 135 | print(b - a) 136 | for f in feats: 137 | print(f.shape, f.min().item()) 138 | print(logit.shape) 139 | -------------------------------------------------------------------------------- /models/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups=2): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N, C, H, W = x.size() 17 | g = self.groups 18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 19 | 20 | 21 | class SplitBlock(nn.Module): 22 | def __init__(self, ratio): 23 | super(SplitBlock, self).__init__() 24 | self.ratio = ratio 25 | 26 | def forward(self, x): 27 | c = int(x.size(1) * self.ratio) 28 | return x[:, :c, :, :], x[:, c:, :, :] 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 33 | super(BasicBlock, self).__init__() 34 | self.is_last = is_last 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | preact = self.bn3(self.conv3(out)) 53 | out = F.relu(preact) 54 | # out = F.relu(self.bn3(self.conv3(out))) 55 | preact = torch.cat([x1, preact], 1) 56 | out = torch.cat([x1, out], 1) 57 | out = self.shuffle(out) 58 | if self.is_last: 59 | return out, preact 60 | else: 61 | return out 62 | 63 | 64 | class DownBlock(nn.Module): 65 | def __init__(self, in_channels, out_channels): 66 | super(DownBlock, self).__init__() 67 | mid_channels = out_channels // 2 68 | # left 69 | self.conv1 = nn.Conv2d(in_channels, in_channels, 70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 71 | self.bn1 = nn.BatchNorm2d(in_channels) 72 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 73 | kernel_size=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(mid_channels) 75 | # 
right 76 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(mid_channels) 79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 81 | self.bn4 = nn.BatchNorm2d(mid_channels) 82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 83 | kernel_size=1, bias=False) 84 | self.bn5 = nn.BatchNorm2d(mid_channels) 85 | 86 | self.shuffle = ShuffleBlock() 87 | 88 | def forward(self, x): 89 | # left 90 | out1 = self.bn1(self.conv1(x)) 91 | out1 = F.relu(self.bn2(self.conv2(out1))) 92 | # right 93 | out2 = F.relu(self.bn3(self.conv3(x))) 94 | out2 = self.bn4(self.conv4(out2)) 95 | out2 = F.relu(self.bn5(self.conv5(out2))) 96 | # concat 97 | out = torch.cat([out1, out2], 1) 98 | out = self.shuffle(out) 99 | return out 100 | 101 | 102 | class ShuffleNetV2(nn.Module): 103 | def __init__(self, net_size, num_classes=10): 104 | super(ShuffleNetV2, self).__init__() 105 | out_channels = configs[net_size]['out_channels'] 106 | num_blocks = configs[net_size]['num_blocks'] 107 | 108 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 109 | # stride=1, padding=1, bias=False) 110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(24) 112 | self.in_channels = 24 113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 117 | kernel_size=1, stride=1, padding=0, bias=False) 118 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 119 | self.linear = nn.Linear(out_channels[3], num_classes) 120 | 121 | def _make_layer(self, out_channels, num_blocks): 122 | layers = [DownBlock(self.in_channels, out_channels)] 123 | for i in range(num_blocks): 124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 125 | self.in_channels = out_channels 126 | return nn.Sequential(*layers) 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.bn1) 132 | feat_m.append(self.layer1) 133 | feat_m.append(self.layer2) 134 | feat_m.append(self.layer3) 135 | return feat_m 136 | 137 | def get_bn_before_relu(self): 138 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 139 | 140 | def forward(self, x, is_feat=False, preact=False): 141 | out = F.relu(self.bn1(self.conv1(x))) 142 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out = F.relu(self.bn2(self.conv2(out))) 151 | out = F.avg_pool2d(out, 4) 152 | out = out.view(out.size(0), -1) 153 | f4 = out 154 | out = self.linear(out) 155 | if is_feat: 156 | if preact: 157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 158 | else: 159 | return [f0, f1, f2, f3, f4], out 160 | else: 161 | return out 162 | 163 | 164 | configs = { 165 | 0.2: { 166 | 'out_channels': (40, 80, 160, 512), 167 | 'num_blocks': (3, 3, 3) 168 | }, 169 | 170 | 0.3: { 171 | 'out_channels': (40, 80, 160, 512), 172 | 'num_blocks': (3, 7, 3) 173 | }, 174 | 175 | 0.5: { 176 | 'out_channels': (48, 96, 192, 1024), 177 | 'num_blocks': (3, 7, 3) 178 | }, 179 | 180 | 1: { 181 | 'out_channels': (116, 232, 464, 1024), 182 | 'num_blocks': (3, 7, 3) 183 | }, 184 | 
1.5: { 185 | 'out_channels': (176, 352, 704, 1024), 186 | 'num_blocks': (3, 7, 3) 187 | }, 188 | 2: { 189 | 'out_channels': (224, 488, 976, 2048), 190 | 'num_blocks': (3, 7, 3) 191 | } 192 | } 193 | 194 | 195 | def ShuffleV2_0_2(**kwargs): 196 | model = ShuffleNetV2(net_size=0.2, **kwargs) 197 | return model 198 | 199 | def ShuffleV2_0_5(**kwargs): 200 | model = ShuffleNetV2(net_size=0.5, **kwargs) 201 | return model 202 | 203 | def ShuffleV2(**kwargs): 204 | model = ShuffleNetV2(net_size=1, **kwargs) 205 | return model 206 | 207 | def ShuffleV2_1_5(**kwargs): 208 | model = ShuffleNetV2(net_size=1.5, **kwargs) 209 | return model 210 | 211 | def ShuffleV2_2_0(**kwargs): 212 | model = ShuffleNetV2(net_size=2.0, **kwargs) 213 | return model 214 | 215 | if __name__ == '__main__': 216 | net = ShuffleV2(num_classes=100) 217 | x = torch.randn(3, 3, 32, 32) 218 | import time 219 | a = time.time() 220 | feats, logit = net(x, is_feat=True, preact=True) 221 | b = time.time() 222 | print(b - a) 223 | for f in feats: 224 | print(f.shape, f.min().item()) 225 | print(logit.shape) 226 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet38, resnet110, resnet116, resnet14x2, resnet38x2, resnet110x2 2 | from .resnet import resnet8x4, resnet14x4, resnet32x4, resnet38x4 3 | from .vgg import vgg8_bn, vgg13_bn 4 | from .mobilenetv2 import mobile_half, mobile_half_double 5 | from .ShuffleNetv1 import ShuffleV1 6 | from .ShuffleNetv2 import ShuffleV2, ShuffleV2_1_5 7 | 8 | from .resnet_imagenet import resnet18, resnet34, resnet50, wide_resnet50_2, resnext50_32x4d 9 | from .resnet_imagenet import wide_resnet10_2, wide_resnet18_2, wide_resnet34_2 10 | from .mobilenetv2_imagenet import mobilenet_v2 11 | from .shuffleNetv2_imagenet import shufflenet_v2_x1_0 12 | 13 | model_dict = { 14 | 'resnet38': resnet38, 15 | 'resnet110': resnet110, 16 | 'resnet116': resnet116, 17 | 'resnet14x2': resnet14x2, 18 | 'resnet38x2': resnet38x2, 19 | 'resnet110x2': resnet110x2, 20 | 'resnet8x4': resnet8x4, 21 | 'resnet14x4': resnet14x4, 22 | 'resnet32x4': resnet32x4, 23 | 'resnet38x4': resnet38x4, 24 | 'vgg8': vgg8_bn, 25 | 'vgg13': vgg13_bn, 26 | 'MobileNetV2': mobile_half, 27 | 'MobileNetV2_1_0': mobile_half_double, 28 | 'ShuffleV1': ShuffleV1, 29 | 'ShuffleV2': ShuffleV2, 30 | 'ShuffleV2_1_5': ShuffleV2_1_5, 31 | 32 | 'ResNet18': resnet18, 33 | 'ResNet34': resnet34, 34 | 'ResNet50': resnet50, 35 | 'resnext50_32x4d': resnext50_32x4d, 36 | 'ResNet10x2': wide_resnet10_2, 37 | 'ResNet18x2': wide_resnet18_2, 38 | 'ResNet34x2': wide_resnet34_2, 39 | 'wrn_50_2': wide_resnet50_2, 40 | 41 | 'MobileNetV2_Imagenet': mobilenet_v2, 42 | 'ShuffleV2_Imagenet': shufflenet_v2_x1_0, 43 | } 44 | -------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv1.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv1.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv1.cpython-39.pyc 
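The `model_dict` registry in `models/__init__.py` above is how architectures are looked up by their command-line names. A minimal usage sketch (CIFAR-100 shapes; the two registry keys are taken from the table above, everything else is a placeholder):

```python
import torch
from models import model_dict

# Instantiate a teacher/student pair by registry key.
model_t = model_dict['resnet32x4'](num_classes=100)
model_s = model_dict['ShuffleV2'](num_classes=100)

x = torch.randn(2, 3, 32, 32)
feat_t, logit_t = model_t(x, is_feat=True)   # list of hidden features + logits
print(len(feat_t), logit_t.shape)            # 5 torch.Size([2, 100])
```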
-------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv2.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/ShuffleNetv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv2.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenetv2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenetv2.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenetv2_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/mobilenetv2_imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2_imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet.cpython-39.pyc 
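As a quick sanity check on the channel shuffle implemented by `ShuffleBlock` in the two ShuffleNet files above (`[N,C,H,W] -> [N,g,C/g,H,W] -> permute -> [N,C,H,W]`), here is what the op does to six labelled channels with `g=2` (toy tensor, illustration only):

```python
import torch

x = torch.arange(6).view(1, 6, 1, 1)    # channels labelled 0..5
g = 2
y = x.view(1, g, 6 // g, 1, 1).permute(0, 2, 1, 3, 4).reshape(1, 6, 1, 1)
print(y.flatten().tolist())              # [0, 3, 1, 4, 2, 5] -> the two groups interleaved
```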
-------------------------------------------------------------------------------- /models/__pycache__/resnet_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/resnet_imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet_imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/shuffleNetv2_imagenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/shuffleNetv2_imagenet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/shuffleNetv2_imagenet.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/shuffleNetv2_imagenet.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vgg.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/vgg.cpython-39.pyc -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNetV2 implementation used in 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['mobilenetv2_T_w', 'mobile_half'] 11 | 12 | BN = None 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | 23 | def conv_1x1_bn(inp, oup): 24 | return nn.Sequential( 25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 26 | nn.BatchNorm2d(oup), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | 31 | class 
InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio): 33 | super(InvertedResidual, self).__init__() 34 | self.blockname = None 35 | 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | nn.ReLU(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] 55 | 56 | def forward(self, x): 57 | t = x 58 | if self.use_res_connect: 59 | return t + self.conv(x) 60 | else: 61 | return self.conv(x) 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | """mobilenetV2""" 66 | def __init__(self, T, 67 | feature_dim, 68 | input_size=32, 69 | width_mult=1., 70 | remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | # building classifier 108 | self.classifier = nn.Sequential( 109 | # nn.Dropout(0.5), 110 | nn.Linear(self.last_channel, feature_dim), 111 | ) 112 | 113 | H = input_size // (32//2) 114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 115 | 116 | self._initialize_weights() 117 | print(T, width_mult) 118 | 119 | def get_feat_modules(self): 120 | feat_m = nn.ModuleList([]) 121 | feat_m.append(self.conv1) 122 | feat_m.append(self.blocks) 123 | return feat_m 124 | 125 | def forward(self, x, is_feat=False, preact=False): 126 | 127 | out = self.conv1(x) 128 | f0 = out 129 | 130 | out = self.blocks[0](out) 131 | out = self.blocks[1](out) 132 | f1 = out 133 | out = self.blocks[2](out) 134 | f2 = out 135 | out = self.blocks[3](out) 136 | out = self.blocks[4](out) 137 | f3 = out 138 | out = self.blocks[5](out) 139 | out = self.blocks[6](out) 140 | f4 = out 141 | 142 | out = self.conv2(out) 143 | 144 | if not self.remove_avg: 145 | out = self.avgpool(out) 146 | out = out.view(out.size(0), -1) 147 | f5 = out 148 | out = self.classifier(out) 149 | 150 | if is_feat: 151 | return [f0, f1, f2, f3, f4, f5], out 152 | else: 153 | return out 154 | 155 | def _initialize_weights(self): 156 | for m in self.modules(): 157 | if isinstance(m, nn.Conv2d): 158 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 159 | 
m.weight.data.normal_(0, math.sqrt(2. / n)) 160 | if m.bias is not None: 161 | m.bias.data.zero_() 162 | elif isinstance(m, nn.BatchNorm2d): 163 | m.weight.data.fill_(1) 164 | m.bias.data.zero_() 165 | elif isinstance(m, nn.Linear): 166 | n = m.weight.size(1) 167 | m.weight.data.normal_(0, 0.01) 168 | m.bias.data.zero_() 169 | 170 | 171 | def mobilenetv2_T_w(T, W, feature_dim=100): 172 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 173 | return model 174 | 175 | # To be consistent with the previous paper (CRD), MobileNetV2 is instantiated by mobile_half 176 | def mobile_half(num_classes): 177 | return mobilenetv2_T_w(6, 0.5, num_classes) 178 | 179 | # MobileNetV2x2 is instantiated by mobile_half_double 180 | def mobile_half_double(num_classes): 181 | return mobilenetv2_T_w(6, 1.0, num_classes) 182 | 183 | if __name__ == '__main__': 184 | x = torch.randn(2, 3, 32, 32) 185 | 186 | net = mobile_half(100) 187 | 188 | feats, logit = net(x, is_feat=True, preact=True) 189 | for f in feats: 190 | print(f.shape, f.min().item()) 191 | print(logit.shape) 192 | 193 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 194 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) 195 | -------------------------------------------------------------------------------- /models/mobilenetv2_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py 2 | # Only two changes: 3 | # 1. MobileNetV2.forward() is modified to return inner feature maps. 4 | # 2. merge utils.py into this file to import load_state_dict_from_url. 5 | 6 | import torch 7 | from torch import nn 8 | 9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | 16 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 17 | 18 | 19 | model_urls = { 20 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 21 | } 22 | 23 | 24 | def _make_divisible(v, divisor, min_value=None): 25 | """ 26 | This function is taken from the original tf repo. 27 | It ensures that all layers have a channel number that is divisible by 8 28 | It can be seen here: 29 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 30 | :param v: 31 | :param divisor: 32 | :param min_value: 33 | :return: 34 | """ 35 | if min_value is None: 36 | min_value = divisor 37 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 38 | # Make sure that round down does not go down by more than 10%. 
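# Worked example (editorial note, not in the original file): _make_divisible(72, 64)
# first rounds to max(64, int(72 + 32) // 64 * 64) = 64; since 64 < 0.9 * 72 = 64.8,
# the branch below adds the divisor back and the function returns 128.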
39 | if new_v < 0.9 * v: 40 | new_v += divisor 41 | return new_v 42 | 43 | 44 | class ConvBNReLU(nn.Sequential): 45 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None): 46 | padding = (kernel_size - 1) // 2 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | super(ConvBNReLU, self).__init__( 50 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 51 | norm_layer(out_planes), 52 | nn.ReLU6(inplace=True) 53 | ) 54 | 55 | 56 | class InvertedResidual(nn.Module): 57 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None): 58 | super(InvertedResidual, self).__init__() 59 | self.stride = stride 60 | assert stride in [1, 2] 61 | 62 | if norm_layer is None: 63 | norm_layer = nn.BatchNorm2d 64 | 65 | hidden_dim = int(round(inp * expand_ratio)) 66 | self.use_res_connect = self.stride == 1 and inp == oup 67 | 68 | layers = [] 69 | if expand_ratio != 1: 70 | # pw 71 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 72 | layers.extend([ 73 | # dw 74 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 75 | # pw-linear 76 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 77 | norm_layer(oup), 78 | ]) 79 | self.conv = nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | if self.use_res_connect: 83 | return x + self.conv(x) 84 | else: 85 | return self.conv(x) 86 | 87 | 88 | class MobileNetV2(nn.Module): 89 | def __init__(self, 90 | num_classes=1000, 91 | width_mult=1.0, 92 | inverted_residual_setting=None, 93 | round_nearest=8, 94 | block=None, 95 | norm_layer=None): 96 | """ 97 | MobileNet V2 main class 98 | 99 | Args: 100 | num_classes (int): Number of classes 101 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 102 | inverted_residual_setting: Network structure 103 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 104 | Set to 1 to turn off rounding 105 | block: Module specifying inverted residual building block for mobilenet 106 | norm_layer: Module specifying the normalization layer to use 107 | 108 | """ 109 | super(MobileNetV2, self).__init__() 110 | 111 | if block is None: 112 | block = InvertedResidual 113 | 114 | if norm_layer is None: 115 | norm_layer = nn.BatchNorm2d 116 | 117 | input_channel = 32 118 | last_channel = 1280 119 | 120 | if inverted_residual_setting is None: 121 | inverted_residual_setting = [ 122 | # t, c, n, s 123 | [1, 16, 1, 1], 124 | [6, 24, 2, 2], 125 | [6, 32, 3, 2], 126 | [6, 64, 4, 2], 127 | [6, 96, 3, 1], 128 | [6, 160, 3, 2], 129 | [6, 320, 1, 1], 130 | ] 131 | 132 | # only check the first element, assuming user knows t,c,n,s are required 133 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 134 | raise ValueError("inverted_residual_setting should be non-empty " 135 | "or a 4-element list, got {}".format(inverted_residual_setting)) 136 | 137 | # building first layer 138 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 139 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 140 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 141 | # building inverted residual blocks 142 | for t, c, n, s in inverted_residual_setting: 143 | output_channel = _make_divisible(c * width_mult, round_nearest) 144 | for i in range(n): 145 | stride = s if i == 0 else 1 146 | 
features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer)) 147 | input_channel = output_channel 148 | # building last several layers 149 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 150 | # do not use nn.Sequential 151 | # self.features = nn.Sequential(*features) 152 | self.features = nn.ModuleList(features) 153 | 154 | # building classifier 155 | self.classifier = nn.Sequential( 156 | nn.Dropout(0.2), 157 | nn.Linear(self.last_channel, num_classes), 158 | ) 159 | 160 | # weight initialization 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 164 | if m.bias is not None: 165 | nn.init.zeros_(m.bias) 166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 167 | nn.init.ones_(m.weight) 168 | nn.init.zeros_(m.bias) 169 | elif isinstance(m, nn.Linear): 170 | nn.init.normal_(m.weight, 0, 0.01) 171 | nn.init.zeros_(m.bias) 172 | 173 | def _forward_impl(self, x): 174 | # This exists since TorchScript doesn't support inheritance, so the superclass method 175 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 176 | # x = self.features(x) 177 | for module in self.features: 178 | x = module(x) 179 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 180 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 181 | x = self.classifier(x) 182 | return x 183 | 184 | def forward(self, x, is_feat=False): 185 | if not is_feat: 186 | return self._forward_impl(x) 187 | 188 | splits = [0, 1, 4, 7, 14, 18] 189 | hidden_layers = [] 190 | for left, right in zip(splits, splits[1:]): 191 | for module in self.features[left:right]: 192 | x = module(x) 193 | hidden_layers.append(x) 194 | for module in self.features[splits[-1]:]: 195 | x = module(x) 196 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1) 197 | hidden_layers.append(x) 198 | x = self.classifier(x) 199 | return hidden_layers, x 200 | 201 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 202 | """ 203 | Constructs a MobileNetV2 architecture from 204 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 205 | 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | progress (bool): If True, displays a progress bar of the download to stderr 209 | """ 210 | model = MobileNetV2(**kwargs) 211 | if pretrained: 212 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 213 | progress=progress) 214 | model.load_state_dict(state_dict) 215 | return model 216 | 217 | if __name__ == '__main__': 218 | #x = torch.randn(2, 3, 32, 32) 219 | x = torch.randn(2, 3, 224, 224) 220 | 221 | #net = mobile_half(100) 222 | net = MobileNetV2() 223 | 224 | feats, logit = net(x, is_feat=True) 225 | for f in feats: 226 | print(f.shape, f.min().item()) 227 | print(logit.shape) 228 | 229 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 230 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | ''' 4 | ResNet for CIFAR-10/100 Dataset (Only BasicBlock is used). 5 | Reference: 6 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 7 | 2. 
https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua 8 | 3. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 9 | Deep Residual Learning for Image Recognition (CVPR 2016). https://arxiv.org/abs/1512.03385 10 | 11 | The Wide ResNet model is the same as ResNet except that the number of channels 12 | is doubled/quadrupled/k-times larger in every BasicBlock. 13 | Reference: 14 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 15 | 2. Sergey Zagoruyko and Nikos Komodakis 16 | Wide Residual Networks (BMVC 2016) http://arxiv.org/abs/1605.07146 17 | 18 | P.S. 19 | Following the previous repository "https://github.com/HobbitLong/RepDistiller", the num_filters of the first conv is doubled in ResNet-8x4/32x4. 20 | The wide ResNet model in "/RepDistiller/models/wrn.py" is almost the same as the ResNet model in "/RepDistiller/models/resnet.py". 21 | For example, wrn_40_2 in "/RepDistiller/models/wrn.py" is almost equal to resnet38x2 in "/RepDistiller/models/resnet.py". 22 | The only difference is that resnet38x2 has three additional BN layers, which leads to 2*(16+32+64)*k additional parameters [k=2 in this comparison]. 23 | Therefore, it is recommended to directly use this file for the implementation of the Wide ResNet model. 24 | ''' 25 | 26 | import torch.nn as nn 27 | 28 | __all__ = ['resnet'] 29 | 30 | def conv3x3(in_planes, out_planes, stride=1): 31 | """3x3 convolution with padding""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 33 | 34 | def conv1x1(in_planes, out_planes, stride=1): 35 | """1x1 convolution""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None): 42 | super(BasicBlock, self).__init__() 43 | self.conv1 = conv3x3(inplanes, planes, stride) 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | return out 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None): 72 | super(Bottleneck, self).__init__() 73 | self.conv1 = conv1x1(inplanes, planes) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = conv3x3(planes, planes, stride)  # apply stride here so the spatial size matches the downsampled residual 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = conv1x1(planes, planes * 4) 78 | self.bn3 = nn.BatchNorm2d(planes * 4) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | return out 103 | 104 | class ResNet(nn.Module): 105 | 106 | def __init__(self, depth, num_filters, block_name='BasicBlock',
num_classes=10): 107 | super(ResNet, self).__init__() 108 | # Model type specifies number of layers for CIFAR-10 model 109 | if block_name.lower() == 'basicblock': 110 | assert (depth - 2) % 6 == 0, 'When using BasicBlock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 111 | n = (depth - 2) // 6 112 | block = BasicBlock 113 | elif block_name.lower() == 'bottleneck': 114 | assert (depth - 2) % 9 == 0, 'When using Bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199' 115 | n = (depth - 2) // 9 116 | block = Bottleneck 117 | else: 118 | raise ValueError('block_name should be BasicBlock or Bottleneck') 119 | 120 | self.inplanes = num_filters[0] 121 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.layer1 = self._make_layer(block, num_filters[1], n) 125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 127 | self.avgpool = nn.AdaptiveAvgPool2d((1,1)) 128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | conv1x1(self.inplanes, planes * block.expansion, stride), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, downsample)) 147 | self.inplanes = planes * block.expansion 148 | for i in range(1, blocks): 149 | layers.append(block(self.inplanes, planes)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def get_feat_modules(self): 154 | feat_m = nn.ModuleList([]) 155 | feat_m.append(self.conv1) 156 | feat_m.append(self.bn1) 157 | feat_m.append(self.relu) 158 | feat_m.append(self.layer1) 159 | feat_m.append(self.layer2) 160 | feat_m.append(self.layer3) 161 | feat_m.append(self.fc) 162 | return feat_m 163 | 164 | def forward(self, x, is_feat=False): 165 | 166 | x = self.conv1(x) 167 | x = self.bn1(x) 168 | x = self.relu(x) # 32x32 169 | f0 = x 170 | 171 | x = self.layer1(x) # 32x32 172 | f1 = x 173 | x = self.layer2(x) # 16x16 174 | f2 = x 175 | x = self.layer3(x) # 8x8 176 | f3 = x 177 | 178 | x = self.avgpool(x) 179 | x = x.view(x.size(0), -1) 180 | f4 = x 181 | x = self.fc(x) 182 | 183 | if is_feat: 184 | return [f0, f1, f2, f3, f4], x 185 | else: 186 | return x 187 | 188 | def resnet8(**kwargs): 189 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) 190 | 191 | def resnet14(**kwargs): 192 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 193 | 194 | def resnet20(**kwargs): 195 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 196 | 197 | def resnet32(**kwargs): 198 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 199 | 200 | # wrn_40_1 (We use the wrn notation to be consistent with the previous work) 201 | def resnet38(**kwargs): 202 | return ResNet(38, [16, 16, 32, 64], 'basicblock', **kwargs) 203 | 204 | def resnet44(**kwargs): 205 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 206 | 207 | def resnet56(**kwargs): 208 |
return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 209 | 210 | def resnet110(**kwargs): 211 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 212 | 213 | def resnet116(**kwargs): 214 | return ResNet(116, [16, 16, 32, 64], 'basicblock', **kwargs) 215 | 216 | def resnet200(**kwargs): 217 | return ResNet(200, [16, 16, 32, 64], 'basicblock', **kwargs) 218 | 219 | # wrn_16_2 (We use the wrn notation to be consistent with the previous work) 220 | def resnet14x2(**kwargs): 221 | return ResNet(14, [16, 32, 64, 128], 'basicblock', **kwargs) 222 | 223 | # wrn_16_4 (We use the wrn notation to be consistent with the previous work) 224 | def resnet14x4(**kwargs): 225 | return ResNet(14, [32, 64, 128, 256], 'basicblock', **kwargs) 226 | 227 | # wrn_40_2 (We use the wrn notation to be consistent with the previous work) 228 | def resnet38x2(**kwargs): 229 | return ResNet(38, [16, 32, 64, 128], 'basicblock', **kwargs) 230 | 231 | def resnet110x2(**kwargs): 232 | return ResNet(110, [16, 32, 64, 128], 'basicblock', **kwargs) 233 | 234 | def resnet8x4(**kwargs): 235 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 236 | 237 | def resnet20x4(**kwargs): 238 | return ResNet(20, [32, 64, 128, 256], 'basicblock', **kwargs) 239 | 240 | def resnet26x4(**kwargs): 241 | return ResNet(26, [32, 64, 128, 256], 'basicblock', **kwargs) 242 | 243 | def resnet32x4(**kwargs): 244 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 245 | 246 | # wrn_40_4 (We use the wrn notation to be consistent with the previous work) 247 | def resnet38x4(**kwargs): 248 | return ResNet(38, [32, 64, 128, 256], 'basicblock', **kwargs) 249 | 250 | def resnet44x4(**kwargs): 251 | return ResNet(44, [32, 64, 128, 256], 'basicblock', **kwargs) 252 | 253 | def resnet56x4(**kwargs): 254 | return ResNet(56, [32, 64, 128, 256], 'basicblock', **kwargs) 255 | 256 | def resnet110x4(**kwargs): 257 | return ResNet(110, [32, 64, 128, 256], 'basicblock', **kwargs) 258 | 259 | if __name__ == '__main__': 260 | import torch 261 | 262 | x = torch.randn(2, 3, 32, 32) 263 | net = resnet32(num_classes=100) 264 | feats, logit = net(x, is_feat=True) 265 | 266 | for f in feats: 267 | print(f.shape, f.min().item()) 268 | print(logit.shape) 269 | 270 | #for i, m in enumerate(net.get_feat_modules()): 271 | # print(i, m) 272 | 273 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 274 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) -------------------------------------------------------------------------------- /models/resnet_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | # Only two changes: 3 | # 1. Resnet.forward() is modified to return inner feature maps. 4 | # 2. merge utils.py into this file to import load_state_dict_from_url. 
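# Editorial note (assumption; the factory functions are not fully shown in this
# excerpt): the extra wide_resnet10_2/18_2/34_2 variants exported in
# models/__init__.py presumably pass width_per_group=128 with BasicBlock;
# ResNet.__init__ below then sets self.multiplier = base_width // 64 = 2,
# doubling every stage's channel count, analogous to the CIFAR resnetNx2 models.
# Hypothetical usage:
#   net = wide_resnet18_2(num_classes=1000)
#   feats, logits = net(torch.randn(2, 3, 224, 224), is_feat=True)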
5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | 16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 17 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 18 | 'wide_resnet50_2', 'wide_resnet101_2', 'resnet34x4'] 19 | 20 | 21 | model_urls = { 22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 29 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 30 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 31 | } 32 | 33 | 34 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 35 | """3x3 convolution with padding""" 36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 37 | padding=dilation, groups=groups, bias=False, dilation=dilation) 38 | 39 | 40 | def conv1x1(in_planes, out_planes, stride=1): 41 | """1x1 convolution""" 42 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 43 | 44 | 45 | class BasicBlock(nn.Module): 46 | expansion = 1 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 49 | base_width=64, dilation=1, norm_layer=None): 50 | super(BasicBlock, self).__init__() 51 | if norm_layer is None: 52 | norm_layer = nn.BatchNorm2d 53 | # if groups != 1 or base_width != 64: 54 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64') 55 | if dilation > 1: 56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 58 | self.conv1 = conv3x3(inplanes, planes, stride) 59 | self.bn1 = norm_layer(planes) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.conv2 = conv3x3(planes, planes) 62 | self.bn2 = norm_layer(planes) 63 | self.downsample = downsample 64 | self.stride = stride 65 | 66 | def forward(self, x): 67 | identity = x 68 | 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out += identity 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | 85 | class Bottleneck(nn.Module): 86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 89 | # This variant is also known as ResNet V1.5 and improves accuracy according to 90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
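    # The 3x3 conv width below scales with base_width and groups:
    #   width = int(planes * (base_width / 64.)) * groups
    # so e.g. wide_resnet50_2 (width_per_group = 64 * 2) doubles the channel
    # count of every bottleneck's inner convolutions.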
91 | 92 | expansion = 4 93 | 94 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 95 | base_width=64, dilation=1, norm_layer=None): 96 | super(Bottleneck, self).__init__() 97 | if norm_layer is None: 98 | norm_layer = nn.BatchNorm2d 99 | width = int(planes * (base_width / 64.)) * groups 100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 101 | self.conv1 = conv1x1(inplanes, width) 102 | self.bn1 = norm_layer(width) 103 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 104 | self.bn2 = norm_layer(width) 105 | self.conv3 = conv1x1(width, planes * self.expansion) 106 | self.bn3 = norm_layer(planes * self.expansion) 107 | self.relu = nn.ReLU(inplace=True) 108 | self.downsample = downsample 109 | self.stride = stride 110 | 111 | def forward(self, x): 112 | identity = x 113 | 114 | out = self.conv1(x) 115 | out = self.bn1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.bn2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.bn3(out) 124 | 125 | if self.downsample is not None: 126 | identity = self.downsample(x) 127 | 128 | out += identity 129 | out = self.relu(out) 130 | 131 | return out 132 | 133 | class ResNet(nn.Module): 134 | 135 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 136 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 137 | norm_layer=None): 138 | super(ResNet, self).__init__() 139 | if norm_layer is None: 140 | norm_layer = nn.BatchNorm2d 141 | self._norm_layer = norm_layer 142 | 143 | self.inplanes = 64 144 | self.dilation = 1 145 | self.multiplier = 1 146 | if replace_stride_with_dilation is None: 147 | # each element in the tuple indicates if we should replace 148 | # the 2x2 stride with a dilated convolution instead 149 | replace_stride_with_dilation = [False, False, False] 150 | if len(replace_stride_with_dilation) != 3: 151 | raise ValueError("replace_stride_with_dilation should be None " 152 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 153 | self.groups = groups 154 | self.base_width = width_per_group 155 | if self.base_width != 64 and (block is BasicBlock): 156 | self.multiplier = self.base_width // 64 157 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 158 | bias=False) 159 | self.bn1 = norm_layer(self.inplanes) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 162 | self.layer1 = self._make_layer(block, int(64 * self.multiplier), layers[0]) 163 | self.layer2 = self._make_layer(block, int(128 * self.multiplier), layers[1], stride=2, 164 | dilate=replace_stride_with_dilation[0]) 165 | self.layer3 = self._make_layer(block, int(256 * self.multiplier), layers[2], stride=2, 166 | dilate=replace_stride_with_dilation[1]) 167 | self.layer4 = self._make_layer(block, int(512 * self.multiplier), layers[3], stride=2, 168 | dilate=replace_stride_with_dilation[2]) 169 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 170 | self.fc = nn.Linear(int(512 * self.multiplier) * block.expansion, num_classes) 171 | 172 | for m in self.modules(): 173 | if isinstance(m, nn.Conv2d): 174 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 175 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 176 | nn.init.constant_(m.weight, 1) 177 | nn.init.constant_(m.bias, 0) 178 | 179 | # Zero-initialize the last BN in each residual branch, 180 | # so that the residual branch starts 
with zeros, and each residual block behaves like an identity. 181 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 182 | if zero_init_residual: 183 | for m in self.modules(): 184 | if isinstance(m, Bottleneck): 185 | nn.init.constant_(m.bn3.weight, 0) 186 | elif isinstance(m, BasicBlock): 187 | nn.init.constant_(m.bn2.weight, 0) 188 | 189 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 190 | norm_layer = self._norm_layer 191 | downsample = None 192 | previous_dilation = self.dilation 193 | if dilate: 194 | self.dilation *= stride 195 | stride = 1 196 | 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | conv1x1(self.inplanes, planes * block.expansion, stride), 200 | norm_layer(planes * block.expansion), 201 | ) 202 | 203 | layers = [] 204 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 205 | self.base_width, previous_dilation, norm_layer)) 206 | self.inplanes = planes * block.expansion 207 | for _ in range(1, blocks): 208 | layers.append(block(self.inplanes, planes, groups=self.groups, 209 | base_width=self.base_width, dilation=self.dilation, 210 | norm_layer=norm_layer)) 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def get_feat_modules(self): 215 | feat_m = nn.ModuleList([]) 216 | feat_m.append(self.conv1) 217 | feat_m.append(self.bn1) 218 | feat_m.append(self.relu) 219 | feat_m.append(self.maxpool) 220 | feat_m.append(self.layer1) 221 | feat_m.append(self.layer2) 222 | feat_m.append(self.layer3) 223 | feat_m.append(self.layer4) 224 | feat_m.append(self.fc) 225 | return feat_m 226 | 227 | def forward(self, x, is_feat=False): 228 | x = self.conv1(x) 229 | x = self.bn1(x) 230 | x = self.relu(x) 231 | x = self.maxpool(x) 232 | f0 = x 233 | 234 | x = self.layer1(x) 235 | f1 = x 236 | x = self.layer2(x) 237 | f2 = x 238 | x = self.layer3(x) 239 | f3 = x 240 | x = self.layer4(x) 241 | f4 = x 242 | 243 | x = self.avgpool(x) 244 | x = torch.flatten(x, 1) 245 | f5 = x 246 | x = self.fc(x) 247 | if is_feat: 248 | return [f0, f1, f2, f3, f4, f5], x 249 | else: 250 | return x 251 | 252 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 253 | model = ResNet(block, layers, **kwargs) 254 | if pretrained: 255 | state_dict = load_state_dict_from_url(model_urls[arch], 256 | progress=progress) 257 | model.load_state_dict(state_dict) 258 | return model 259 | 260 | 261 | def resnet18(pretrained=False, progress=True, **kwargs): 262 | r"""ResNet-18 model from 263 | `"Deep Residual Learning for Image Recognition" `_ 264 | 265 | Args: 266 | pretrained (bool): If True, returns a model pre-trained on ImageNet 267 | progress (bool): If True, displays a progress bar of the download to stderr 268 | """ 269 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 270 | **kwargs) 271 | 272 | def wide_resnet10_2(pretrained=False, progress=True, **kwargs): 273 | 274 | kwargs['width_per_group'] = 64 * 2 275 | return _resnet('wide_resnet10_2', BasicBlock, [1, 1, 1, 1], pretrained, progress, 276 | **kwargs) 277 | 278 | def wide_resnet18_2(pretrained=False, progress=True, **kwargs): 279 | 280 | kwargs['width_per_group'] = 64 * 2 281 | return _resnet('wide_resnet18_2', BasicBlock, [2, 2, 2, 2], pretrained, progress, 282 | **kwargs) 283 | 284 | def wide_resnet26_2(pretrained=False, progress=True, **kwargs): 285 | 286 | kwargs['width_per_group'] = 64 * 2 287 | return _resnet('wide_resnet26_2', BasicBlock, [3, 3, 3, 3], pretrained, progress, 
288 | **kwargs) 289 | 290 | def resnet34(pretrained=False, progress=True, **kwargs): 291 | r"""ResNet-34 model from 292 | `"Deep Residual Learning for Image Recognition" `_ 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | progress (bool): If True, displays a progress bar of the download to stderr 297 | """ 298 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 299 | **kwargs) 300 | 301 | def resnet50(pretrained=False, progress=True, **kwargs): 302 | r"""ResNet-50 model from 303 | `"Deep Residual Learning for Image Recognition" `_ 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | progress (bool): If True, displays a progress bar of the download to stderr 308 | """ 309 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 310 | **kwargs) 311 | 312 | def resnet101(pretrained=False, progress=True, **kwargs): 313 | r"""ResNet-101 model from 314 | `"Deep Residual Learning for Image Recognition" `_ 315 | 316 | Args: 317 | pretrained (bool): If True, returns a model pre-trained on ImageNet 318 | progress (bool): If True, displays a progress bar of the download to stderr 319 | """ 320 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 321 | **kwargs) 322 | 323 | def resnet152(pretrained=False, progress=True, **kwargs): 324 | r"""ResNet-152 model from 325 | `"Deep Residual Learning for Image Recognition" `_ 326 | 327 | Args: 328 | pretrained (bool): If True, returns a model pre-trained on ImageNet 329 | progress (bool): If True, displays a progress bar of the download to stderr 330 | """ 331 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 332 | **kwargs) 333 | 334 | def wide_resnet34_2(pretrained=False, progress=True, **kwargs): 335 | 336 | kwargs['width_per_group'] = 64 * 2 337 | return _resnet('wide_resnet34_2', BasicBlock, [3, 4, 6, 3], pretrained, progress, 338 | **kwargs) 339 | 340 | def wide_resnet34_4(pretrained=False, progress=True, **kwargs): 341 | 342 | kwargs['width_per_group'] = 64 * 4 343 | return _resnet('wide_resnet34_4', BasicBlock, [3, 4, 6, 3], pretrained, progress, 344 | **kwargs) 345 | 346 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 347 | r"""ResNeXt-50 32x4d model from 348 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 349 | 350 | Args: 351 | pretrained (bool): If True, returns a model pre-trained on ImageNet 352 | progress (bool): If True, displays a progress bar of the download to stderr 353 | """ 354 | kwargs['groups'] = 32 355 | kwargs['width_per_group'] = 4 356 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 357 | pretrained, progress, **kwargs) 358 | 359 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 360 | r"""ResNeXt-101 32x8d model from 361 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 362 | 363 | Args: 364 | pretrained (bool): If True, returns a model pre-trained on ImageNet 365 | progress (bool): If True, displays a progress bar of the download to stderr 366 | """ 367 | kwargs['groups'] = 32 368 | kwargs['width_per_group'] = 8 369 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 370 | pretrained, progress, **kwargs) 371 | 372 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 373 | r"""Wide ResNet-50-2 model from 374 | `"Wide Residual Networks" `_ 375 | 376 | The model is the same as ResNet except for the bottleneck number of channels 377 | 
which is twice larger in every block. The number of channels in outer 1x1 378 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 379 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 380 | 381 | Args: 382 | pretrained (bool): If True, returns a model pre-trained on ImageNet 383 | progress (bool): If True, displays a progress bar of the download to stderr 384 | """ 385 | kwargs['width_per_group'] = 64 * 2 386 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 387 | pretrained, progress, **kwargs) 388 | 389 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 390 | r"""Wide ResNet-101-2 model from 391 | `"Wide Residual Networks" `_ 392 | 393 | The model is the same as ResNet except for the bottleneck number of channels 394 | which is twice larger in every block. The number of channels in outer 1x1 395 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 396 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 397 | 398 | Args: 399 | pretrained (bool): If True, returns a model pre-trained on ImageNet 400 | progress (bool): If True, displays a progress bar of the download to stderr 401 | """ 402 | kwargs['width_per_group'] = 64 * 2 403 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 404 | pretrained, progress, **kwargs) 405 | 406 | if __name__ == '__main__': 407 | 408 | x = torch.randn(64, 3, 224, 224) 409 | 410 | net = wide_resnet10_2() 411 | 412 | feats, logit = net(x, is_feat=True) 413 | for f in feats: 414 | print(f.shape, f.min().item()) 415 | print(logit.shape) 416 | 417 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 418 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) -------------------------------------------------------------------------------- /models/shuffleNetv2_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py 2 | # Only two changes: 3 | # 1. ShuffleNet is modified to return inner feature maps. 4 | # 2. merge utils.py into this file to import load_state_dict_from_url. 
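# A quick sanity check for channel_shuffle below (a sketch, not part of the
# upstream file): with two groups the channels of the two branches are interleaved:
#   x = torch.arange(8.).view(1, 8, 1, 1)
#   channel_shuffle(x, 2).flatten()  # tensor([0., 4., 1., 5., 2., 6., 3., 7.])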
5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | 16 | 17 | __all__ = [ 18 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 19 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 20 | ] 21 | 22 | model_urls = { 23 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 24 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 25 | 'shufflenetv2_x1.5': None, 26 | 'shufflenetv2_x2.0': None, 27 | } 28 | 29 | 30 | def channel_shuffle(x, groups): 31 | # type: (torch.Tensor, int) -> torch.Tensor 32 | batchsize, num_channels, height, width = x.data.size() 33 | channels_per_group = num_channels // groups 34 | 35 | # reshape 36 | x = x.view(batchsize, groups, 37 | channels_per_group, height, width) 38 | 39 | x = torch.transpose(x, 1, 2).contiguous() 40 | 41 | # flatten 42 | x = x.view(batchsize, -1, height, width) 43 | 44 | return x 45 | 46 | 47 | class InvertedResidual(nn.Module): 48 | def __init__(self, inp, oup, stride): 49 | super(InvertedResidual, self).__init__() 50 | 51 | if not (1 <= stride <= 3): 52 | raise ValueError('illegal stride value') 53 | self.stride = stride 54 | 55 | branch_features = oup // 2 56 | assert (self.stride != 1) or (inp == branch_features << 1) 57 | 58 | if self.stride > 1: 59 | self.branch1 = nn.Sequential( 60 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 61 | nn.BatchNorm2d(inp), 62 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 63 | nn.BatchNorm2d(branch_features), 64 | nn.ReLU(inplace=True), 65 | ) 66 | else: 67 | self.branch1 = nn.Sequential() 68 | 69 | self.branch2 = nn.Sequential( 70 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 71 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 72 | nn.BatchNorm2d(branch_features), 73 | nn.ReLU(inplace=True), 74 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 75 | nn.BatchNorm2d(branch_features), 76 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 77 | nn.BatchNorm2d(branch_features), 78 | nn.ReLU(inplace=True), 79 | ) 80 | 81 | @staticmethod 82 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 83 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 84 | 85 | def forward(self, x): 86 | if self.stride == 1: 87 | x1, x2 = x.chunk(2, dim=1) 88 | out = torch.cat((x1, self.branch2(x2)), dim=1) 89 | else: 90 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 91 | 92 | out = channel_shuffle(out, 2) 93 | 94 | return out 95 | 96 | 97 | class ShuffleNetV2(nn.Module): 98 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): 99 | super(ShuffleNetV2, self).__init__() 100 | 101 | if len(stages_repeats) != 3: 102 | raise ValueError('expected stages_repeats as list of 3 positive ints') 103 | if len(stages_out_channels) != 5: 104 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 105 | self._stage_out_channels = stages_out_channels 106 | 107 | input_channels = 3 108 | output_channels = self._stage_out_channels[0] 109 | self.conv1 = nn.Sequential( 110 | 
nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 111 | nn.BatchNorm2d(output_channels), 112 | nn.ReLU(inplace=True), 113 | ) 114 | input_channels = output_channels 115 | 116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 117 | 118 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 119 | for name, repeats, output_channels in zip( 120 | stage_names, stages_repeats, self._stage_out_channels[1:]): 121 | seq = [inverted_residual(input_channels, output_channels, 2)] 122 | for _ in range(repeats - 1): 123 | seq.append(inverted_residual(output_channels, output_channels, 1)) 124 | setattr(self, name, nn.Sequential(*seq)) 125 | input_channels = output_channels 126 | 127 | output_channels = self._stage_out_channels[-1] 128 | self.conv5 = nn.Sequential( 129 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 130 | nn.BatchNorm2d(output_channels), 131 | nn.ReLU(inplace=True), 132 | ) 133 | 134 | self.fc = nn.Linear(output_channels, num_classes) 135 | 136 | def get_feat_modules(self): 137 | feat_m = nn.ModuleList([]) 138 | feat_m.append(self.conv1) 139 | feat_m.append(self.maxpool) 140 | # (stages are named stage2..stage4 via setattr above; there is no stage1) 141 | feat_m.append(self.stage2) 142 | feat_m.append(self.stage3) 143 | feat_m.append(self.stage4) 144 | feat_m.append(self.conv5) 145 | feat_m.append(self.fc) 146 | return feat_m 147 | 148 | def forward(self, x, is_feat=False): 149 | hidden_layers = [] 150 | x = self.conv1(x) 151 | x = self.maxpool(x) 152 | hidden_layers.append(x) 153 | x = self.stage2(x) 154 | hidden_layers.append(x) 155 | x = self.stage3(x) 156 | hidden_layers.append(x) 157 | x = self.stage4(x) 158 | x = self.conv5(x) 159 | hidden_layers.append(x) 160 | x = x.mean([2, 3]) # globalpool 161 | hidden_layers.append(x) 162 | x = self.fc(x) 163 | if not is_feat: 164 | return x 165 | else: 166 | return hidden_layers, x 167 | 168 | 169 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 170 | model = ShuffleNetV2(*args, **kwargs) 171 | 172 | if pretrained: 173 | model_url = model_urls[arch] 174 | if model_url is None: 175 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 176 | else: 177 | state_dict = load_state_dict_from_url(model_url, progress=progress) 178 | model.load_state_dict(state_dict) 179 | 180 | return model 181 | 182 | 183 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 184 | """ 185 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 186 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 187 | `_. 188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | progress (bool): If True, displays a progress bar of the download to stderr 192 | """ 193 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 194 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 195 | 196 | 197 | 198 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 199 | """ 200 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 201 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 202 | `_. 
203 | 204 | Args: 205 | pretrained (bool): If True, returns a model pre-trained on ImageNet 206 | progress (bool): If True, displays a progress bar of the download to stderr 207 | """ 208 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 209 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 210 | 211 | 212 | 213 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 214 | """ 215 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 216 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 217 | `_. 218 | 219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | progress (bool): If True, displays a progress bar of the download to stderr 222 | """ 223 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 224 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 225 | 226 | 227 | 228 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 229 | """ 230 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 231 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 232 | `_. 233 | 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | progress (bool): If True, displays a progress bar of the download to stderr 237 | """ 238 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 239 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 240 | 241 | if __name__ == '__main__': 242 | #x = torch.randn(2, 3, 32, 32) 243 | x = torch.randn(2, 3, 224, 224) 244 | 245 | #net = mobile_half(100) 246 | net = shufflenet_v2_x1_0() 247 | 248 | feats, logit = net(x, is_feat=True) 249 | for f in feats: 250 | print(f.shape, f.min().item()) 251 | print(logit.shape) 252 | 253 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 254 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class ConvReg(nn.Module): 8 | """Convolutional regression for FitNet (feature-map layer)""" 9 | def __init__(self, s_shape, t_shape): 10 | super(ConvReg, self).__init__() 11 | _, s_C, s_H, s_W = s_shape 12 | _, t_C, t_H, t_W = t_shape 13 | self.s_H = s_H 14 | self.t_H = t_H 15 | if s_H == 2 * t_H: 16 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) 17 | elif s_H * 2 == t_H: 18 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 19 | elif s_H >= t_H: 20 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 21 | else: 22 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, padding=1, stride=1) 23 | self.bn = nn.BatchNorm2d(t_C) 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, x, t): 27 | x = self.conv(x) 28 | if self.s_H == 2 * self.t_H or self.s_H * 2 == self.t_H or self.s_H >= self.t_H: 29 | return self.relu(self.bn(x)), t 30 | else: 31 | return self.relu(self.bn(x)), F.adaptive_avg_pool2d(t, (self.s_H, self.s_H)) 32 | 33 | class SelfA(nn.Module): 34 | """Cross-layer Self Attention""" 35 | def __init__(self, feat_dim, s_n, t_n, soft, factor=4): 36 | super(SelfA, self).__init__() 37 | 38 | self.soft = soft 39 | self.s_len = len(s_n) 40 | self.t_len = len(t_n) 41 | self.feat_dim = feat_dim 42 | 43 | # query and key mapping 
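        # (one MLPEmbed per student layer builds queries and one per teacher
        # layer builds keys; both embed the bsz x bsz self-similarity matrices
        # computed in forward() into a shared feat_dim // factor space, and a
        # Proj regressor is created for every student-teacher layer pair to
        # align channel dimensions)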
44 | for i in range(self.s_len): 45 | setattr(self, 'query_'+str(i), MLPEmbed(feat_dim, feat_dim//factor)) 46 | for i in range(self.t_len): 47 | setattr(self, 'key_'+str(i), MLPEmbed(feat_dim, feat_dim//factor)) 48 | 49 | for i in range(self.s_len): 50 | for j in range(self.t_len): 51 | setattr(self, 'regressor'+str(i)+str(j), Proj(s_n[i], t_n[j])) 52 | 53 | def forward(self, feat_s, feat_t): 54 | 55 | sim_s = list(range(self.s_len)) 56 | sim_t = list(range(self.t_len)) 57 | bsz = self.feat_dim 58 | 59 | # similarity matrix 60 | for i in range(self.s_len): 61 | sim_temp = feat_s[i].reshape(bsz, -1) 62 | sim_s[i] = torch.matmul(sim_temp, sim_temp.t()) 63 | for i in range(self.t_len): 64 | sim_temp = feat_t[i].reshape(bsz, -1) 65 | sim_t[i] = torch.matmul(sim_temp, sim_temp.t()) 66 | 67 | # calculate student query 68 | proj_query = self.query_0(sim_s[0]) 69 | proj_query = proj_query[:, None, :] 70 | for i in range(1, self.s_len): 71 | temp_proj_query = getattr(self, 'query_'+str(i))(sim_s[i]) 72 | proj_query = torch.cat([proj_query, temp_proj_query[:, None, :]], 1) 73 | 74 | # calculate teacher key 75 | proj_key = self.key_0(sim_t[0]) 76 | proj_key = proj_key[:, :, None] 77 | for i in range(1, self.t_len): 78 | temp_proj_key = getattr(self, 'key_'+str(i))(sim_t[i]) 79 | proj_key = torch.cat([proj_key, temp_proj_key[:, :, None]], 2) 80 | 81 | # attention weight: batch_size X No. stu feature X No.tea feature 82 | energy = torch.bmm(proj_query, proj_key)/self.soft 83 | attention = F.softmax(energy, dim = -1) 84 | 85 | # feature dimension alignment 86 | proj_value_stu = [] 87 | value_tea = [] 88 | for i in range(self.s_len): 89 | proj_value_stu.append([]) 90 | value_tea.append([]) 91 | for j in range(self.t_len): 92 | s_H, t_H = feat_s[i].shape[2], feat_t[j].shape[2] 93 | if s_H > t_H: 94 | source = F.adaptive_avg_pool2d(feat_s[i], (t_H, t_H)) 95 | target = feat_t[j] 96 | elif s_H <= t_H: 97 | source = feat_s[i] 98 | target = F.adaptive_avg_pool2d(feat_t[j], (s_H, s_H)) 99 | 100 | proj_value_stu[i].append(getattr(self, 'regressor'+str(i)+str(j))(source)) 101 | value_tea[i].append(target) 102 | 103 | return proj_value_stu, value_tea, attention 104 | 105 | class Proj(nn.Module): 106 | """feature dimension alignment by 1x1, 3x3, 1x1 convolutions""" 107 | def __init__(self, num_input_channels=1024, num_target_channels=128): 108 | super(Proj, self).__init__() 109 | self.num_mid_channel = 2 * num_target_channels 110 | 111 | def conv1x1(in_channels, out_channels, stride=1): 112 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False) 113 | def conv3x3(in_channels, out_channels, stride=1): 114 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False) 115 | 116 | self.regressor = nn.Sequential( 117 | conv1x1(num_input_channels, self.num_mid_channel), 118 | nn.BatchNorm2d(self.num_mid_channel), 119 | nn.ReLU(inplace=True), 120 | conv3x3(self.num_mid_channel, self.num_mid_channel), 121 | nn.BatchNorm2d(self.num_mid_channel), 122 | nn.ReLU(inplace=True), 123 | conv1x1(self.num_mid_channel, num_target_channels), 124 | ) 125 | 126 | def forward(self, x): 127 | x = self.regressor(x) 128 | return x 129 | 130 | class MLPEmbed(nn.Module): 131 | """non-linear mapping for attention calculation""" 132 | def __init__(self, dim_in=1024, dim_out=128): 133 | super(MLPEmbed, self).__init__() 134 | self.linear1 = nn.Linear(dim_in, 2 * dim_out) 135 | self.relu = nn.ReLU(inplace=True) 136 | self.linear2 = nn.Linear(2 * dim_out, dim_out) 137 | 
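        # l2-normalizing the embeddings keeps the attention logits scale-invariant;
        # note that forward() only uses linear1/linear2 + l2norm, while the
        # regressor defined below is never called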
self.l2norm = Normalize(2) 138 | self.regressor = nn.Sequential( 139 | nn.Linear(dim_in, 2 * dim_out), 140 | self.l2norm, 141 | nn.ReLU(inplace=True), 142 | nn.Linear(2 * dim_out, dim_out), 143 | self.l2norm, 144 | ) 145 | 146 | def forward(self, x): 147 | x = x.view(x.shape[0], -1) 148 | x = self.relu(self.linear1(x)) 149 | x = self.l2norm(self.linear2(x)) 150 | 151 | return x 152 | 153 | class Normalize(nn.Module): 154 | """normalization layer""" 155 | def __init__(self, power=2): 156 | super(Normalize, self).__init__() 157 | self.power = power 158 | 159 | def forward(self, x): 160 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 161 | out = x.div(norm) 162 | return out 163 | 164 | class SRRL(nn.Module): 165 | """ICLR-2021: Knowledge Distillation via Softmax Regression Representation Learning""" 166 | def __init__(self, *, s_n, t_n): 167 | super(SRRL, self).__init__() 168 | 169 | def conv1x1(in_channels, out_channels, stride=1): 170 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False) 171 | 172 | setattr(self, 'transfer', nn.Sequential( 173 | conv1x1(s_n, t_n), 174 | nn.BatchNorm2d(t_n), 175 | nn.ReLU(inplace=True), 176 | )) 177 | 178 | def forward(self, feat_s, cls_t): 179 | 180 | feat_s = feat_s.unsqueeze(-1).unsqueeze(-1) 181 | temp_feat = self.transfer(feat_s) 182 | trans_feat_s = temp_feat.view(temp_feat.size(0), -1) 183 | 184 | pred_feat_s=cls_t(trans_feat_s) 185 | 186 | return trans_feat_s, pred_feat_s 187 | 188 | class SimKD(nn.Module): 189 | """CVPR-2022: Knowledge Distillation with the Reused Teacher Classifier""" 190 | def __init__(self, *, s_n, t_n, factor=2): 191 | super(SimKD, self).__init__() 192 | 193 | self.avg_pool = nn.AdaptiveAvgPool2d((1,1)) 194 | 195 | def conv1x1(in_channels, out_channels, stride=1): 196 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False) 197 | def conv3x3(in_channels, out_channels, stride=1, groups=1): 198 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False, groups=groups) 199 | 200 | # A bottleneck design to reduce extra parameters 201 | setattr(self, 'transfer', nn.Sequential( 202 | conv1x1(s_n, t_n//factor), 203 | nn.BatchNorm2d(t_n//factor), 204 | nn.ReLU(inplace=True), 205 | conv3x3(t_n//factor, t_n//factor), 206 | # depthwise convolution 207 | #conv3x3(t_n//factor, t_n//factor, groups=t_n//factor), 208 | nn.BatchNorm2d(t_n//factor), 209 | nn.ReLU(inplace=True), 210 | conv1x1(t_n//factor, t_n), 211 | nn.BatchNorm2d(t_n), 212 | nn.ReLU(inplace=True), 213 | )) 214 | 215 | def forward(self, feat_s, feat_t, cls_t): 216 | 217 | # Spatial Dimension Alignment 218 | s_H, t_H = feat_s.shape[2], feat_t.shape[2] 219 | if s_H > t_H: 220 | source = F.adaptive_avg_pool2d(feat_s, (t_H, t_H)) 221 | target = feat_t 222 | else: 223 | source = feat_s 224 | target = F.adaptive_avg_pool2d(feat_t, (s_H, s_H)) 225 | 226 | trans_feat_t=target 227 | 228 | # Channel Alignment 229 | trans_feat_s = getattr(self, 'transfer')(source) 230 | 231 | # Prediction via Teacher Classifier 232 | temp_feat = self.avg_pool(trans_feat_s) 233 | temp_feat = temp_feat.view(temp_feat.size(0), -1) 234 | pred_feat_s = cls_t(temp_feat) 235 | 236 | return trans_feat_s, trans_feat_t, pred_feat_s -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Three FC layers of VGG-ImageNet are replaced with 
single one, 3 | thus the total layer number should be reduced by two on CIFAR-100. 4 | For example, the actual number of layers for VGG-8 is 6. 5 | 6 | VGG for CIFAR10. FC layers are removed. 7 | (c) YANG, Wei 8 | ''' 9 | import math 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | __all__ = [ 14 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 15 | 'vgg19_bn', 'vgg19', 16 | ] 17 | 18 | class VGG(nn.Module): 19 | 20 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 21 | super(VGG, self).__init__() 22 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 23 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 24 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 25 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 26 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 27 | 28 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 29 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 30 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 31 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 32 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 33 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | self.classifier = nn.Linear(512, num_classes) 37 | self._initialize_weights() 38 | 39 | def get_feat_modules(self): 40 | feat_m = nn.ModuleList([]) 41 | feat_m.append(self.block0) 42 | feat_m.append(self.pool0) 43 | feat_m.append(self.block1) 44 | feat_m.append(self.pool1) 45 | feat_m.append(self.block2) 46 | feat_m.append(self.pool2) 47 | feat_m.append(self.block3) 48 | feat_m.append(self.pool3) 49 | feat_m.append(self.block4) 50 | feat_m.append(self.pool4) 51 | feat_m.append(self.classifier) 52 | return feat_m 53 | 54 | def forward(self, x, is_feat=False): 55 | h = x.shape[2] 56 | x = F.relu(self.block0(x)) 57 | f0 = x 58 | x = self.pool0(x) 59 | x = self.block1(x) 60 | x = self.relu(x) 61 | f1 = x 62 | x = self.pool1(x) 63 | x = self.block2(x) 64 | x = self.relu(x) 65 | f2 = x 66 | x = self.pool2(x) 67 | x = self.block3(x) 68 | x = self.relu(x) 69 | f3 = x 70 | if h == 64: 71 | x = self.pool3(x) 72 | x = self.block4(x) 73 | x = self.relu(x) 74 | f4 = x 75 | x = self.pool4(x) 76 | x = x.view(x.size(0), -1) 77 | f5 = x 78 | x = self.classifier(x) 79 | 80 | if is_feat: 81 | return [f0, f1, f2, f3, f4, f5], x 82 | else: 83 | return x 84 | 85 | @staticmethod 86 | def _make_layers(cfg, batch_norm=False, in_channels=3): 87 | layers = [] 88 | for v in cfg: 89 | if v == 'M': 90 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 91 | else: 92 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 93 | if batch_norm: 94 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 95 | else: 96 | layers += [conv2d, nn.ReLU(inplace=True)] 97 | in_channels = v 98 | layers = layers[:-1] 99 | return nn.Sequential(*layers) 100 | 101 | def _initialize_weights(self): 102 | for m in self.modules(): 103 | if isinstance(m, nn.Conv2d): 104 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 105 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 106 | if m.bias is not None: 107 | m.bias.data.zero_() 108 | elif isinstance(m, nn.BatchNorm2d): 109 | m.weight.data.fill_(1) 110 | m.bias.data.zero_() 111 | elif isinstance(m, nn.Linear): 112 | n = m.weight.size(1) 113 | m.weight.data.normal_(0, 0.01) 114 | m.bias.data.zero_() 115 | 116 | cfg = { 117 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 118 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 119 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 120 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 121 | 'S': [[64], [128], [256], [512], [512]], 122 | } 123 | 124 | def vgg8(**kwargs): 125 | """VGG 8-layer model (configuration "S")""" 126 | model = VGG(cfg['S'], **kwargs) 127 | return model 128 | 129 | 130 | def vgg8_bn(**kwargs): 131 | """VGG 8-layer model (configuration "S")""" 132 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 133 | return model 134 | 135 | 136 | def vgg11(**kwargs): 137 | """VGG 11-layer model (configuration "A")""" 138 | model = VGG(cfg['A'], **kwargs) 139 | return model 140 | 141 | 142 | def vgg11_bn(**kwargs): 143 | """VGG 11-layer model (configuration "A") with batch normalization""" 144 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 145 | return model 146 | 147 | 148 | def vgg13(**kwargs): 149 | """VGG 13-layer model (configuration "B")""" 150 | model = VGG(cfg['B'], **kwargs) 151 | return model 152 | 153 | 154 | def vgg13_bn(**kwargs): 155 | """VGG 13-layer model (configuration "B") with batch normalization""" 156 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 157 | return model 158 | 159 | 160 | def vgg16(**kwargs): 161 | """VGG 16-layer model (configuration "D")""" 162 | model = VGG(cfg['D'], **kwargs) 163 | return model 164 | 165 | 166 | def vgg16_bn(**kwargs): 167 | """VGG 16-layer model (configuration "D") with batch normalization""" 168 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 169 | return model 170 | 171 | 172 | def vgg19(**kwargs): 173 | """VGG 19-layer model (configuration "E")""" 174 | model = VGG(cfg['E'], **kwargs) 175 | return model 176 | 177 | 178 | def vgg19_bn(**kwargs): 179 | """VGG 19-layer model (configuration 'E') with batch normalization""" 180 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 181 | return model 182 | 183 | 184 | if __name__ == '__main__': 185 | import torch 186 | 187 | x = torch.randn(2, 3, 32, 32) 188 | net = vgg13_bn(num_classes=100) 189 | feats, logit = net(x, is_feat=True) 190 | 191 | for f in feats: 192 | print(f.shape, f.min().item()) 193 | print(logit.shape) 194 | 195 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0) 196 | print('Total params_stu: {:.3f} M'.format(num_params_stu)) 197 | -------------------------------------------------------------------------------- /scripts/run_distill.sh: -------------------------------------------------------------------------------- 1 | # sample scripts for running various knowledge distillation approaches 2 | # we use resnet32x4 and resnet8x4 as an example 3 | 4 | # CIFAR 5 | # KD 6 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill kd --model_s resnet8x4 -c 1 -d 1 -b 0 --trial 0 --gpu_id 0 7 | # FitNet 8 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill hint --model_s resnet8x4 -c 1 -d 1 -b 100 --trial 0 --gpu_id 0 9 | # AT 10 | python train_student.py --path_t 
./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill attention --model_s resnet8x4 -c 1 -d 1 -b 1000 --trial 0 --gpu_id 0 11 | # SP 12 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill similarity --model_s resnet8x4 -c 1 -d 1 -b 3000 --trial 0 --gpu_id 0 13 | # VID 14 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill vid --model_s resnet8x4 -c 1 -d 1 -b 1 --trial 0 --gpu_id 0 15 | # CRD 16 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill crd --model_s resnet8x4 -c 1 -d 1 -b 0.8 --trial 0 --gpu_id 0 17 | # SemCKD 18 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill semckd --model_s resnet8x4 -c 1 -d 1 -b 400 --trial 0 --gpu_id 0 19 | # SRRL 20 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill srrl --model_s resnet8x4 -c 1 -d 1 -b 1 --trial 0 --gpu_id 0 21 | # SimKD 22 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0 --gpu_id 0 23 | 24 | # ImageNet 25 | python train_student.py --path_t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23344 --multiprocessing-distributed --dali gpu --trial 0 26 | -------------------------------------------------------------------------------- /scripts/run_vanilla.sh: -------------------------------------------------------------------------------- 1 | # sample scripts for training vanilla teacher/student models 2 | 3 | # CIFAR 4 | python train_teacher.py --model resnet8x4 --trial 0 --gpu_id 0 5 | 6 | python train_teacher.py --model resnet32x4 --trial 0 --gpu_id 0 7 | 8 | python train_teacher.py --model resnet110 --trial 0 --gpu_id 0 9 | 10 | python train_teacher.py --model resnet116 --trial 0 --gpu_id 0 11 | 12 | python train_teacher.py --model resnet110x2 --trial 0 --gpu_id 0 13 | 14 | python train_teacher.py --model vgg8 --trial 0 --gpu_id 0 15 | 16 | python train_teacher.py --model vgg13 --trial 0 --gpu_id 0 17 | 18 | python train_teacher.py --model ShuffleV1 --trial 0 --gpu_id 0 19 | 20 | python train_teacher.py --model ShuffleV2 --trial 0 --gpu_id 0 21 | 22 | python train_teacher.py --model ShuffleV2_1_5 --trial 0 --gpu_id 0 23 | 24 | python train_teacher.py --model MobileNetV2 --trial 0 --gpu_id 0 25 | 26 | python train_teacher.py --model MobileNetV2_1_0 --trial 0 --gpu_id 0 27 | 28 | # WRN-40-1 29 | python train_teacher.py --model resnet38 --trial 0 --gpu_id 0 30 | # WRN-40-2 31 | python train_teacher.py --model resnet38x2 --trial 0 --gpu_id 0 32 | # WRN-16-2 33 | python train_teacher.py --model resnet14x2 --trial 0 --gpu_id 0 34 | # WRN-40-4 35 | python train_teacher.py --model resnet38x4 --trial 0 --gpu_id 0 36 | # WRN-16-4 37 | python train_teacher.py --model resnet14x4 --trial 0 --gpu_id 0 38 | 39 | # ImageNet 40 | python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0 41 | 42 | 43 | 44 | 
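# The SimKD ImageNet command in run_distill.sh loads a ResNet50 teacher checkpoint
# (ResNet50_best.pth); a matching teacher can be trained with the same recipe as the
# ResNet18 command above (hyper-parameters assumed, mirroring that command):
python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet50 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0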
-------------------------------------------------------------------------------- /train_student.py: -------------------------------------------------------------------------------- 1 | """ 2 | the general training framework 3 | """ 4 | 5 | from __future__ import print_function 6 | 7 | import os 8 | import re 9 | import argparse 10 | import time 11 | 12 | import numpy 13 | import torch 14 | import torch.optim as optim 15 | import torch.multiprocessing as mp 16 | import torch.distributed as dist 17 | import torch.nn as nn 18 | import torch.backends.cudnn as cudnn 19 | import tensorboard_logger as tb_logger 20 | 21 | from models import model_dict 22 | from models.util import ConvReg, SelfA, SRRL, SimKD 23 | 24 | from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample 25 | from dataset.imagenet import get_imagenet_dataloader, get_dataloader_sample 26 | from dataset.imagenet_dali import get_dali_data_loader 27 | 28 | from helper.loops import train_distill as train, validate_vanilla, validate_distill 29 | from helper.util import save_dict_to_json, reduce_tensor, adjust_learning_rate 30 | 31 | from crd.criterion import CRDLoss 32 | from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, VIDLoss, SemCKDLoss 33 | 34 | split_symbol = '~' if os.name == 'nt' else ':' 35 | 36 | def parse_option(): 37 | 38 | parser = argparse.ArgumentParser('argument for training') 39 | 40 | # basic 41 | parser.add_argument('--print_freq', type=int, default=200, help='print frequency') 42 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size') 43 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use') 44 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs') 45 | parser.add_argument('--gpu_id', type=str, default='0', help='id(s) for CUDA_VISIBLE_DEVICES') 46 | 47 | # optimization 48 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate') 49 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list') 50 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate') 51 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay') 52 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 53 | 54 | # dataset and model 55 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset') 56 | parser.add_argument('--model_s', type=str, default='resnet8x4') 57 | parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot') 58 | 59 | # distillation 60 | parser.add_argument('--trial', type=str, default='1', help='trial id') 61 | parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation') 62 | parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'hint', 'attention', 'similarity', 'vid', 63 | 'crd', 'semckd','srrl', 'simkd']) 64 | parser.add_argument('-c', '--cls', type=float, default=1.0, help='weight for classification') 65 | parser.add_argument('-d', '--div', type=float, default=1.0, help='weight balance for KD') 66 | parser.add_argument('-b', '--beta', type=float, default=0.0, help='weight balance for other losses') 67 | parser.add_argument('-f', '--factor', type=int, default=2, help='factor size of SimKD') 68 | parser.add_argument('-s', '--soft', type=float, default=1.0, help='attention scale 
of SemCKD') 69 | 70 | # hint layer 71 | parser.add_argument('--hint_layer', default=1, type=int, choices=[0, 1, 2, 3, 4]) 72 | 73 | # NCE distillation 74 | parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension') 75 | parser.add_argument('--mode', default='exact', type=str, choices=['exact', 'relax']) 76 | parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE') 77 | parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax') 78 | parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates') 79 | 80 | # multiprocessing 81 | parser.add_argument('--dali', type=str, choices=['cpu', 'gpu'], default=None) 82 | parser.add_argument('--multiprocessing-distributed', action='store_true', 83 | help='Use multi-processing distributed training to launch ' 84 | 'N processes per node, which has N GPUs. This is the ' 85 | 'fastest way to use PyTorch for either single node or ' 86 | 'multi node data parallel training') 87 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:23451', type=str, 88 | help='url used to set up distributed training') 89 | parser.add_argument('--deterministic', action='store_true', help='Make results reproducible') 90 | parser.add_argument('--skip-validation', action='store_true', help='Skip validation of teacher') 91 | 92 | opt = parser.parse_args() 93 | 94 | # set different learning rates for these MobileNet/ShuffleNet models 95 | if opt.model_s in ['MobileNetV2', 'MobileNetV2_1_0', 'ShuffleV1', 'ShuffleV2', 'ShuffleV2_1_5']: 96 | opt.learning_rate = 0.01 97 | 98 | # set the path of model and tensorboard 99 | opt.model_path = './save/students/models' 100 | opt.tb_path = './save/students/tensorboard' 101 | 102 | iterations = opt.lr_decay_epochs.split(',') 103 | opt.lr_decay_epochs = list([]) 104 | for it in iterations: 105 | opt.lr_decay_epochs.append(int(it)) 106 | 107 | opt.model_t = get_teacher_name(opt.path_t) 108 | 109 | model_name_template = split_symbol.join(['S', '{}_T', '{}_{}_{}_r', '{}_a', '{}_b', '{}_{}']) 110 | opt.model_name = model_name_template.format(opt.model_s, opt.model_t, opt.dataset, opt.distill, 111 | opt.cls, opt.div, opt.beta, opt.trial) 112 | 113 | if opt.dali is not None: 114 | opt.model_name += '_dali:' + opt.dali 115 | 116 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name) 117 | if not os.path.isdir(opt.tb_folder): 118 | os.makedirs(opt.tb_folder) 119 | 120 | opt.save_folder = os.path.join(opt.model_path, opt.model_name) 121 | if not os.path.isdir(opt.save_folder): 122 | os.makedirs(opt.save_folder) 123 | 124 | return opt 125 | 126 | def get_teacher_name(model_path): 127 | """parse teacher name""" 128 | directory = model_path.split('/')[-2] 129 | pattern = ''.join(['S', split_symbol, '(.+)', '_T', split_symbol]) 130 | name_match = re.match(pattern, directory) 131 | if name_match: 132 | return name_match[1] 133 | segments = directory.split('_') 134 | if segments[0] == 'wrn': 135 | return segments[0] + '_' + segments[1] + '_' + segments[2] 136 | return segments[0] 137 | 138 | 139 | def load_teacher(model_path, n_cls, gpu=None, opt=None): 140 | print('==> loading teacher model') 141 | model_t = get_teacher_name(model_path) 142 | model = model_dict[model_t](num_classes=n_cls) 143 | map_location = None if gpu is None else {'cuda:0': 'cuda:%d' % (gpu if opt.multiprocessing_distributed else 0)} 144 | model.load_state_dict(torch.load(model_path, map_location=map_location)['model']) 145 | print('==> 
done') 146 | return model 147 | 148 | 149 | best_acc = 0 150 | total_time = time.time() 151 | def main(): 152 | 153 | opt = parse_option() 154 | 155 | # ASSIGN CUDA_ID 156 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id 157 | 158 | ngpus_per_node = torch.cuda.device_count() 159 | opt.ngpus_per_node = ngpus_per_node 160 | if opt.multiprocessing_distributed: 161 | # Since we have ngpus_per_node processes per node, the total world_size 162 | # needs to be adjusted accordingly 163 | world_size = 1 164 | opt.world_size = ngpus_per_node * world_size 165 | # Use torch.multiprocessing.spawn to launch distributed processes: the 166 | # main_worker process function 167 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) 168 | else: 169 | main_worker(None if ngpus_per_node > 1 else opt.gpu_id, ngpus_per_node, opt) 170 | 171 | def main_worker(gpu, ngpus_per_node, opt): 172 | global best_acc, total_time 173 | opt.gpu = int(gpu) 174 | opt.gpu_id = int(gpu) 175 | 176 | if opt.gpu is not None: 177 | print("Use GPU: {} for training".format(opt.gpu)) 178 | 179 | if opt.multiprocessing_distributed: 180 | # Only one node now. 181 | opt.rank = gpu 182 | dist_backend = 'nccl' 183 | dist.init_process_group(backend=dist_backend, init_method=opt.dist_url, 184 | world_size=opt.world_size, rank=opt.rank) 185 | opt.batch_size = int(opt.batch_size / ngpus_per_node) 186 | opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node) 187 | 188 | if opt.deterministic: 189 | torch.manual_seed(12345) 190 | cudnn.deterministic = True 191 | cudnn.benchmark = False 192 | numpy.random.seed(12345) 193 | 194 | # model 195 | n_cls = { 196 | 'cifar100': 100, 197 | 'imagenet': 1000, 198 | }.get(opt.dataset, None) 199 | 200 | model_t = load_teacher(opt.path_t, n_cls, opt.gpu, opt) 201 | try: 202 | model_s = model_dict[opt.model_s](num_classes=n_cls) 203 | except KeyError: 204 | raise NotImplementedError('model_s {} is not supported'.format(opt.model_s)) 205 | 206 | if opt.dataset == 'cifar100': 207 | data = torch.randn(2, 3, 32, 32) 208 | elif opt.dataset == 'imagenet': 209 | data = torch.randn(2, 3, 224, 224) 210 | 211 | model_t.eval() 212 | model_s.eval() 213 | feat_t, _ = model_t(data, is_feat=True) 214 | feat_s, _ = model_s(data, is_feat=True) 215 | 216 | module_list = nn.ModuleList([]) 217 | module_list.append(model_s) 218 | trainable_list = nn.ModuleList([]) 219 | trainable_list.append(model_s) 220 | 221 | criterion_cls = nn.CrossEntropyLoss() 222 | criterion_div = DistillKL(opt.kd_T) 223 | if opt.distill == 'kd': 224 | criterion_kd = DistillKL(opt.kd_T) 225 | elif opt.distill == 'hint': 226 | criterion_kd = HintLoss() 227 | regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape) 228 | module_list.append(regress_s) 229 | trainable_list.append(regress_s) 230 | elif opt.distill == 'attention': 231 | criterion_kd = Attention() 232 | elif opt.distill == 'similarity': 233 | criterion_kd = Similarity() 234 | elif opt.distill == 'vid': 235 | s_n = [f.shape[1] for f in feat_s[1:-1]] 236 | t_n = [f.shape[1] for f in feat_t[1:-1]] 237 | criterion_kd = nn.ModuleList( 238 | [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)] 239 | ) 240 | # add this as some parameters in VIDLoss need to be updated 241 | trainable_list.append(criterion_kd) 242 | elif opt.distill == 'crd': 243 | opt.s_dim = feat_s[-1].shape[1] 244 | opt.t_dim = feat_t[-1].shape[1] 245 | if opt.dataset == 'cifar100': 246 | opt.n_data = 50000 247 | else: 248 | opt.n_data = 1281167 249 | criterion_kd = CRDLoss(opt) 250 | 
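        # CRD learns two projection heads (embed_s for the student, embed_t for
        # the teacher); both take part in the forward pass and receive gradient
        # updates, hence they go into module_list and trainable_list below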
module_list.append(criterion_kd.embed_s) 251 | module_list.append(criterion_kd.embed_t) 252 | trainable_list.append(criterion_kd.embed_s) 253 | trainable_list.append(criterion_kd.embed_t) 254 | elif opt.distill == 'semckd': 255 | s_n = [f.shape[1] for f in feat_s[1:-1]] 256 | t_n = [f.shape[1] for f in feat_t[1:-1]] 257 | criterion_kd = SemCKDLoss() 258 | self_attention = SelfA(opt.batch_size, s_n, t_n, opt.soft) 259 | module_list.append(self_attention) 260 | trainable_list.append(self_attention) 261 | elif opt.distill == 'srrl': 262 | s_n = feat_s[-1].shape[1] 263 | t_n = feat_t[-1].shape[1] 264 | model_fmsr = SRRL(s_n= s_n, t_n=t_n) 265 | criterion_kd = nn.MSELoss() 266 | module_list.append(model_fmsr) 267 | trainable_list.append(model_fmsr) 268 | elif opt.distill == 'simkd': 269 | s_n = feat_s[-2].shape[1] 270 | t_n = feat_t[-2].shape[1] 271 | model_simkd = SimKD(s_n= s_n, t_n=t_n, factor=opt.factor) 272 | criterion_kd = nn.MSELoss() 273 | module_list.append(model_simkd) 274 | trainable_list.append(model_simkd) 275 | else: 276 | raise NotImplementedError(opt.distill) 277 | 278 | criterion_list = nn.ModuleList([]) 279 | criterion_list.append(criterion_cls) # classification loss 280 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation 281 | criterion_list.append(criterion_kd) # other knowledge distillation loss 282 | 283 | module_list.append(model_t) 284 | 285 | optimizer = optim.SGD(trainable_list.parameters(), 286 | lr=opt.learning_rate, 287 | momentum=opt.momentum, 288 | weight_decay=opt.weight_decay) 289 | 290 | if torch.cuda.is_available(): 291 | # For multiprocessing distributed, DistributedDataParallel constructor 292 | # should always set the single device scope, otherwise, 293 | # DistributedDataParallel will use all available devices. 
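        # Each entry of module_list (student, auxiliary distillation modules,
        # and the teacher) is wrapped in DistributedDataParallel separately and
        # pinned to this process's GPU via device_ids=[opt.gpu].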
294 | if opt.multiprocessing_distributed: 295 | if opt.gpu is not None: 296 | torch.cuda.set_device(opt.gpu) 297 | module_list.cuda(opt.gpu) 298 | distributed_modules = [] 299 | for module in module_list: 300 | DDP = torch.nn.parallel.DistributedDataParallel 301 | distributed_modules.append(DDP(module, device_ids=[opt.gpu])) 302 | module_list = distributed_modules 303 | criterion_list.cuda(opt.gpu) 304 | else: 305 | print('multiprocessing_distributed must be used with a specific gpu id') 306 | else: 307 | criterion_list.cuda() 308 | module_list.cuda() 309 | if not opt.deterministic: 310 | cudnn.benchmark = True 311 | 312 | # dataloader 313 | if opt.dataset == 'cifar100': 314 | if opt.distill in ['crd']: 315 | train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(batch_size=opt.batch_size, 316 | num_workers=opt.num_workers, 317 | k=opt.nce_k, 318 | mode=opt.mode) 319 | else: 320 | train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, 321 | num_workers=opt.num_workers) 322 | elif opt.dataset == 'imagenet': 323 | if opt.dali is None: 324 | if opt.distill in ['crd']: 325 | train_loader, val_loader, n_data, _, train_sampler = get_dataloader_sample(dataset=opt.dataset, batch_size=opt.batch_size, 326 | num_workers=opt.num_workers, 327 | is_sample=True, 328 | k=opt.nce_k, 329 | multiprocessing_distributed=opt.multiprocessing_distributed) 330 | else: 331 | train_loader, val_loader, train_sampler = get_imagenet_dataloader(dataset=opt.dataset, batch_size=opt.batch_size, 332 | num_workers=opt.num_workers, 333 | multiprocessing_distributed=opt.multiprocessing_distributed) 334 | else: 335 | train_loader, val_loader = get_dali_data_loader(opt) 336 | else: 337 | raise NotImplementedError(opt.dataset) 338 | 339 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0: 340 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2) 341 | 342 | if not opt.skip_validation: 343 | # validate teacher accuracy 344 | teacher_acc, _, _ = validate_vanilla(val_loader, model_t, criterion_cls, opt) 345 | 346 | if opt.dali is not None: 347 | val_loader.reset() 348 | 349 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0: 350 | print('teacher accuracy: ', teacher_acc) 351 | else: 352 | print('Skipping teacher validation.') 353 | 354 | # routine 355 | for epoch in range(1, opt.epochs + 1): 356 | torch.cuda.empty_cache() 357 | if opt.multiprocessing_distributed: 358 | if opt.dali is None: 359 | train_sampler.set_epoch(epoch) 360 | 361 | adjust_learning_rate(epoch, opt, optimizer) 362 | print("==> training...") 363 | 364 | time1 = time.time() 365 | train_acc, train_acc_top5, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt) 366 | time2 = time.time() 367 | 368 | if opt.multiprocessing_distributed: 369 | metrics = torch.tensor([train_acc, train_acc_top5, train_loss]).cuda(opt.gpu, non_blocking=True) 370 | reduced = reduce_tensor(metrics, opt.world_size if 'world_size' in opt else 1) 371 | train_acc, train_acc_top5, train_loss = reduced.tolist() 372 | 373 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0: 374 | print(' * Epoch {}, GPU {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}'.format(epoch, opt.gpu, train_acc, train_acc_top5, time2 - time1)) 375 | 376 | logger.log_value('train_acc', train_acc, epoch) 377 | logger.log_value('train_loss', train_loss, epoch) 378 | 379 | print('GPU %d validating' % (opt.gpu)) 380 | test_acc, test_acc_top5, test_loss = validate_distill(val_loader, 
                                                              module_list, criterion_cls, opt)

        if opt.dali is not None:
            train_loader.reset()
            val_loader.reset()

        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))

            logger.log_value('test_acc', test_acc, epoch)
            logger.log_value('test_loss', test_loss, epoch)
            logger.log_value('test_acc_top5', test_acc_top5, epoch)

            # save the best model
            if test_acc > best_acc:
                best_acc = test_acc
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'best_acc': best_acc,
                }
                if opt.distill == 'simkd':
                    # also save the trained projector; at test time it lets the
                    # student reuse the teacher classifier
                    state['proj'] = trainable_list[-1].state_dict()
                save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))

                test_metrics = {'test_loss': test_loss,
                                'test_acc': test_acc,
                                'test_acc_top5': test_acc_top5,
                                'epoch': epoch}

                save_dict_to_json(test_metrics, os.path.join(opt.save_folder, "test_best_metrics.json"))
                print('saving the best model!')
                torch.save(state, save_file)

    if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
        # This best accuracy is only for printing purposes.
        print('best accuracy:', best_acc)

        # save parameters
        save_state = {k: v for k, v in opt._get_kwargs()}
        # number of parameters (M)
        num_params = (sum(p.numel() for p in model_s.parameters()) / 1000000.0)
        save_state['Total params'] = num_params
        save_state['Total time'] = (time.time() - total_time) / 3600.0
        params_json_path = os.path.join(opt.save_folder, "parameters.json")
        save_dict_to_json(save_state, params_json_path)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/train_teacher.py:
--------------------------------------------------------------------------------
"""
Training a single model (student or teacher)
"""

import os
import argparse
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
import tensorboard_logger as tb_logger

from models import model_dict
from dataset.cifar100 import get_cifar100_dataloaders
from dataset.imagenet import get_imagenet_dataloader
from dataset.imagenet_dali import get_dali_data_loader
from helper.util import save_dict_to_json, reduce_tensor, adjust_learning_rate
from helper.loops import train_vanilla as train, validate_vanilla

def parse_option():

    parser = argparse.ArgumentParser('argument for training')

    # basic
    parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
    parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
    parser.add_argument('--num_workers', type=int, default=8, help='num_workers')
    parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
    parser.add_argument('--gpu_id', type=str, default='0', help='id(s) for CUDA_VISIBLE_DEVICES')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')

    # dataset
    parser.add_argument('--model', type=str, default='resnet32x4')
    parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset')
    parser.add_argument('-t', '--trial', type=str, default='0', help='the experiment id')
    parser.add_argument('--dali', type=str, choices=['cpu', 'gpu'], default=None)

    # multiprocessing
    parser.add_argument('--multiprocessing-distributed', action='store_true',
                        help='Use multi-processing distributed training to launch '
                             'N processes per node, which has N GPUs. This is the '
                             'fastest way to use PyTorch for either single node or '
                             'multi node data parallel training')
    parser.add_argument('--dist-url', default='tcp://127.0.0.1:23451', type=str,
                        help='url used to set up distributed training')

    opt = parser.parse_args()

    # set a lower learning rate for these MobileNet/ShuffleNet models
    if opt.model in ['MobileNetV2', 'MobileNetV2_1_0', 'ShuffleV1', 'ShuffleV2', 'ShuffleV2_1_5']:
        opt.learning_rate = 0.01

    # set the path of model and tensorboard
    opt.model_path = './save/teachers/models'
    opt.tb_path = './save/teachers/tensorboard'

    # parse the comma-separated decay milestones into a list of ints
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    # set the model name
    opt.model_name = '{}_vanilla_{}_trial_{}'.format(opt.model, opt.dataset, opt.trial)
    if opt.dali is not None:
        opt.model_name += '_dali:' + opt.dali

    opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
    if not os.path.isdir(opt.tb_folder):
        os.makedirs(opt.tb_folder)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)

    return opt

best_acc = 0
total_time = time.time()

def main():
    opt = parse_option()

    # ASSIGN CUDA_ID
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    ngpus_per_node = torch.cuda.device_count()
    opt.ngpus_per_node = ngpus_per_node
    if opt.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        world_size = 1
        opt.world_size = ngpus_per_node * world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
    else:
        # gpu is None for non-distributed multi-GPU training (handled by DataParallel)
        main_worker(None if ngpus_per_node > 1 else opt.gpu_id, ngpus_per_node, opt)

def main_worker(gpu, ngpus_per_node, opt):
    global best_acc, total_time
    opt.gpu = int(gpu) if gpu is not None else None
    opt.gpu_id = opt.gpu

    if opt.gpu is not None:
        print("Use GPU: {} for training".format(opt.gpu))

    if opt.multiprocessing_distributed:
        # Only one node now.
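        # With multiple nodes, the global rank would typically be computed as
        # node_rank * ngpus_per_node + gpu; on a single node, the local GPU
        # index serves directly as the global rank below.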
        opt.rank = int(gpu)
        dist_backend = 'nccl'
        dist.init_process_group(backend=dist_backend, init_method=opt.dist_url,
                                world_size=opt.world_size, rank=opt.rank)

    # model
    n_cls = {
        'cifar100': 100,
        'imagenet': 1000,
    }.get(opt.dataset, None)

    try:
        model = model_dict[opt.model](num_classes=n_cls)
    except KeyError:
        # fail loudly instead of continuing with an undefined model
        raise NotImplementedError('model {} is not supported'.format(opt.model))

    # optimizer
    optimizer = optim.SGD(model.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    criterion = nn.CrossEntropyLoss()

    if torch.cuda.is_available():
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise,
        # DistributedDataParallel will use all available devices.
        if opt.multiprocessing_distributed:
            if opt.gpu is not None:
                torch.cuda.set_device(opt.gpu)
                model = model.cuda(opt.gpu)
                criterion = criterion.cuda(opt.gpu)
                # When using a single GPU per process and per
                # DistributedDataParallel, we need to divide the batch size
                # ourselves based on the total number of GPUs we have
                opt.batch_size = int(opt.batch_size / ngpus_per_node)
                opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])
            else:
                print('multiprocessing_distributed must be used with a specific gpu id')
        else:
            criterion = criterion.cuda()
            if torch.cuda.device_count() > 1:
                model = nn.DataParallel(model).cuda()
            else:
                model = model.cuda()

    cudnn.benchmark = True

    # dataloader
    if opt.dataset == 'cifar100':
        train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
    elif opt.dataset == 'imagenet':
        if opt.dali is None:
            train_loader, val_loader, train_sampler = get_imagenet_dataloader(
                dataset=opt.dataset,
                batch_size=opt.batch_size, num_workers=opt.num_workers,
                multiprocessing_distributed=opt.multiprocessing_distributed)
        else:
            train_loader, val_loader = get_dali_data_loader(opt)
    else:
        raise NotImplementedError(opt.dataset)

    # tensorboard
    if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
        logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # routine
    for epoch in range(1, opt.epochs + 1):
        if opt.multiprocessing_distributed:
            if opt.dali is None:
                train_sampler.set_epoch(epoch)
                # No test_sampler needed: set_epoch only re-seeds the shuffling
                # of the training sampler, and validation runs sequentially.
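        # adjust_learning_rate (helper/util.py) presumably applies the step decay
        # configured by --lr_decay_epochs / --lr_decay_rate, i.e. roughly
        # (hypothetical sketch of the helper, not its verbatim code):
        #   steps = sum(epoch > e for e in opt.lr_decay_epochs)
        #   lr = opt.learning_rate * (opt.lr_decay_rate ** steps)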
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc_top5, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
        time2 = time.time()

        if opt.multiprocessing_distributed:
            # average the training metrics over all processes
            metrics = torch.tensor([train_acc, train_acc_top5, train_loss]).cuda(opt.gpu, non_blocking=True)
            reduced = reduce_tensor(metrics, opt.world_size if 'world_size' in opt else 1)
            train_acc, train_acc_top5, train_loss = reduced.tolist()

        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print(' * Epoch {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}'.format(epoch, train_acc, train_acc_top5, time2 - time1))

            logger.log_value('train_acc', train_acc, epoch)
            logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate_vanilla(val_loader, model, criterion, opt)

        if opt.dali is not None:
            train_loader.reset()
            val_loader.reset()

        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))

            logger.log_value('test_acc', test_acc, epoch)
            logger.log_value('test_acc_top5', test_acc_top5, epoch)
            logger.log_value('test_loss', test_loss, epoch)

            # save the best model
            if test_acc > best_acc:
                best_acc = test_acc
                state = {
                    'epoch': epoch,
                    'best_acc': best_acc,
                    # unwrap DDP/DataParallel so the checkpoint loads into a plain model
                    'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
                }
                save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model))

                test_metrics = {'test_loss': float('%.2f' % test_loss),
                                'test_acc': float('%.2f' % test_acc),
                                'test_acc_top5': float('%.2f' % test_acc_top5),
                                'epoch': epoch}

                save_dict_to_json(test_metrics, os.path.join(opt.save_folder, "test_best_metrics.json"))

                print('saving the best model!')
                torch.save(state, save_file)

    if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
        # This best accuracy is only for printing purposes.
        print('best accuracy:', best_acc)

        # save parameters
        state = {k: v for k, v in opt._get_kwargs()}

        # number of parameters (M)
        num_params = (sum(p.numel() for p in model.parameters()) / 1000000.0)
        state['Total params'] = num_params
        state['Total time'] = float('%.2f' % ((time.time() - total_time) / 3600.0))
        params_json_path = os.path.join(opt.save_folder, "parameters.json")
        save_dict_to_json(state, params_json_path)

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
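For reference, a minimal sketch of reloading the `{model}_best.pth` checkpoint that `train_teacher.py` saves above. The checkpoint keys (`epoch`, `best_acc`, `model`) and the save path follow the code shown; the model name and trial id are example values from the README commands, so adjust them to your own run.

```python
# Minimal sketch: reload the best teacher checkpoint written by train_teacher.py.
# Assumes the checkpoint layout shown above ('epoch', 'best_acc', 'model') and
# the default save path under ./save/teachers/models.
import torch

from models import model_dict

ckpt_path = './save/teachers/models/resnet32x4_vanilla_cifar100_trial_0/resnet32x4_best.pth'
checkpoint = torch.load(ckpt_path, map_location='cpu')

model = model_dict['resnet32x4'](num_classes=100)   # same constructor call as in main_worker
model.load_state_dict(checkpoint['model'])          # weights are stored under the 'model' key
model.eval()

print('best epoch:', checkpoint['epoch'], 'best acc:', checkpoint['best_acc'])
```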