├── .DS_Store
├── README.md
├── crd
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── criterion.cpython-36.pyc
│   │   ├── criterion.cpython-39.pyc
│   │   ├── memory.cpython-36.pyc
│   │   └── memory.cpython-39.pyc
│   ├── criterion.py
│   └── memory.py
├── dataset
│   ├── __pycache__
│   │   ├── cifar100.cpython-36.pyc
│   │   ├── cifar100.cpython-39.pyc
│   │   ├── imagenet.cpython-36.pyc
│   │   ├── imagenet.cpython-39.pyc
│   │   ├── imagenet_dali.cpython-36.pyc
│   │   └── imagenet_dali.cpython-39.pyc
│   ├── base.py
│   ├── cifar100.py
│   ├── imagenet.py
│   └── imagenet_dali.py
├── distiller_zoo
│   ├── AT.py
│   ├── FitNet.py
│   ├── KD.py
│   ├── SP.py
│   ├── SemCKD.py
│   ├── VID.py
│   ├── __init__.py
│   └── __pycache__
│       ├── AT.cpython-36.pyc
│       ├── AT.cpython-39.pyc
│       ├── FitNet.cpython-36.pyc
│       ├── FitNet.cpython-39.pyc
│       ├── KD.cpython-36.pyc
│       ├── KD.cpython-39.pyc
│       ├── SP.cpython-36.pyc
│       ├── SP.cpython-39.pyc
│       ├── SemCKD.cpython-36.pyc
│       ├── SemCKD.cpython-39.pyc
│       ├── VID.cpython-36.pyc
│       ├── VID.cpython-39.pyc
│       ├── __init__.cpython-36.pyc
│       └── __init__.cpython-39.pyc
├── helper
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── loops.cpython-36.pyc
│   │   ├── loops.cpython-39.pyc
│   │   ├── util.cpython-36.pyc
│   │   └── util.cpython-39.pyc
│   ├── loops.py
│   └── util.py
├── images
│   ├── .DS_Store
│   ├── SimKD_result.png
│   └── cifar100_result.png
├── models
│   ├── ShuffleNetv1.py
│   ├── ShuffleNetv2.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── ShuffleNetv1.cpython-36.pyc
│   │   ├── ShuffleNetv1.cpython-39.pyc
│   │   ├── ShuffleNetv2.cpython-36.pyc
│   │   ├── ShuffleNetv2.cpython-39.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── mobilenetv2.cpython-36.pyc
│   │   ├── mobilenetv2.cpython-39.pyc
│   │   ├── mobilenetv2_imagenet.cpython-36.pyc
│   │   ├── mobilenetv2_imagenet.cpython-39.pyc
│   │   ├── resnet.cpython-36.pyc
│   │   ├── resnet.cpython-39.pyc
│   │   ├── resnet_imagenet.cpython-36.pyc
│   │   ├── resnet_imagenet.cpython-39.pyc
│   │   ├── shuffleNetv2_imagenet.cpython-36.pyc
│   │   ├── shuffleNetv2_imagenet.cpython-39.pyc
│   │   ├── util.cpython-36.pyc
│   │   ├── util.cpython-39.pyc
│   │   ├── vgg.cpython-36.pyc
│   │   └── vgg.cpython-39.pyc
│   ├── mobilenetv2.py
│   ├── mobilenetv2_imagenet.py
│   ├── resnet.py
│   ├── resnet_imagenet.py
│   ├── shuffleNetv2_imagenet.py
│   ├── util.py
│   └── vgg.py
├── scripts
│   ├── run_distill.sh
│   └── run_vanilla.sh
├── train_student.py
└── train_teacher.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimKD
2 |
3 | Knowledge Distillation with the Reused Teacher Classifier (CVPR-2022) https://arxiv.org/abs/2203.14001
4 |
5 | # Toolbox for KD research
6 |
7 | This repository aims to provide a compact and easy-to-use implementation of several representative knowledge distillation approaches on standard image classification tasks (e.g., CIFAR-100, ImageNet).
8 |
9 | - Generally, these KD approaches include a **classification loss**, a **logit-level distillation loss**, and an additional **feature distillation loss**. For fair comparison and ease of tuning, *we fix the weights of the first two loss terms to **one** throughout all experiments* (`--cls 1 --div 1`). A minimal sketch of how the three terms combine is given right after this list.
10 |
11 | - The following approaches are currently supported by this toolbox, covering vanilla KD, feature-map and feature-embedding distillation, and instance-level and pairwise-level distillation:
12 | - [x] [Vanilla KD](https://arxiv.org/abs/1503.02531), [FitNet](https://arxiv.org/abs/1412.6550) [ICLR-2015], [AT](https://arxiv.org/abs/1612.03928) [ICLR-2017], [SP](https://arxiv.org/abs/1907.09682) [ICCV-2019], [VID](https://openaccess.thecvf.com/content_CVPR_2019/papers/Ahn_Variational_Information_Distillation_for_Knowledge_Transfer_CVPR_2019_paper.pdf) [CVPR-2019]
13 | - [x] [CRD](https://arxiv.org/abs/1910.10699) [ICLR-2020], [SRRL](https://openreview.net/forum?id=ZzwDy_wiWv) [ICLR-2021], [SemCKD](https://arxiv.org/abs/2012.03236) [AAAI-2021]
14 | - [ ] KR [CVPR-2021]
15 | - [x] [SimKD](https://arxiv.org/abs/2203.14001) [CVPR-2022]
16 |
17 | - This toolbox is built on an [open-source benchmark](https://github.com/HobbitLong/RepDistiller) and our [previous repository](https://github.com/DefangChen/SemCKD). Implementations of more KD approaches can be found there.
18 |
19 | - Computing Infrastructure:
20 | - We use one NVIDIA GeForce RTX 2080Ti GPU (PyTorch 1.0) for the CIFAR-100 experiments and four NVIDIA A40 GPUs (PyTorch 1.10) for the ImageNet experiments.
21 | - For ImageNet, we use [DALI](https://github.com/NVIDIA/DALI) for data loading and pre-processing.
22 |
23 | - The code has recently been reorganized and has not been tested thoroughly. If you run into any problems, please do not hesitate to contact us.
24 |
25 | - Please place the CIFAR-100 and ImageNet datasets in `../data/`.
26 |
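As a minimal, illustrative sketch (the variable names mirror the `--cls`/`--div`/`--beta` flags; the real combination happens inside the training loop), the overall objective is a weighted sum of the three terms:

```python
# Illustrative only: the three loss terms weighted by the command-line flags
# (e.g., "-c 1 -d 1 -b 1"); loss_kd is the method-specific feature loss.
loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd
```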
27 | ## Get the pretrained teacher models
28 |
29 | ```bash
30 | # CIFAR-100
31 | python train_teacher.py --batch_size 64 --epochs 240 --dataset cifar100 --model resnet32x4 --learning_rate 0.05 --lr_decay_epochs 150,180,210 --weight_decay 5e-4 --trial 0 --gpu_id 0
32 |
33 | # ImageNet
34 | python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0
35 | ```
36 |
37 | The pretrained teacher models used in our paper are available at this link [[GoogleDrive]](https://drive.google.com/drive/folders/1j7b8TmftKIRC7ChUwAqVWPIocSiacvP4?usp=sharing).
38 |
39 | ## Train the student models with various KD approaches
40 |
41 | ```bash
42 | # CIFAR-100
43 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0
44 |
45 | # ImageNet
46 | python train_student.py --path_t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23444 --multiprocessing-distributed --dali gpu --trial 0
47 | ```
48 | More scripts are provided in `./scripts`.
49 |
50 | ## Some results on CIFAR-100
51 |
52 | | | ResNet-8x4 | VGG-8 | ShuffleNetV2x1.5 |
53 | | ----- | ---- | ---- | ---- |
54 | | **Student**| 73.09 | 70.46 | 74.15 |
55 | | KD | 74.42 | 72.73 | 76.82 |
56 | | FitNet | 74.32 | 72.91 | 77.12 |
57 | | AT | 75.07 | 71.90 | 77.51 |
58 | | SP | 74.29 | 73.12 | 77.18 |
59 | | VID | 74.55 | 73.19 | 77.11 |
60 | | CRD | 75.59 | 73.54 | 77.66 |
61 | | SRRL | 75.39 | 73.23 | 77.55 |
62 | | SemCKD | 76.23 | 75.27 | 79.13 |
63 | | SimKD (f=8) | **76.73** | 74.74 | 78.96 |
64 | | SimKD (f=4) | **77.88** | **75.62** | **79.48** |
65 | | SimKD (f=2) | **78.08** | **75.76** | **79.54** |
66 | | **Teacher (ResNet-32x4)** | 79.42 | 79.42 | 79.42 |
67 |
68 |
69 | 
70 |
71 | (Left) The cross-entropy loss between model predictions and test labels. (Right) The top-1 test accuracy (%) (Student: ResNet-8x4, Teacher: ResNet-32x4).
72 |
73 |
74 | ## Citation
75 | If you find this repository useful, please consider citing the following papers:
76 |
77 |
78 | ```
79 | @inproceedings{chen2022simkd,
80 | title={Knowledge Distillation with the Reused Teacher Classifier},
81 | author={Chen, Defang and Mei, Jian-Ping and Zhang, Hailin and Wang, Can and Feng, Yan and Chen, Chun},
82 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
83 | pages={11933--11942},
84 | year={2022}
85 | }
86 | ```
87 | ```
88 | @inproceedings{chen2021cross,
89 | author = {Defang Chen and Jian{-}Ping Mei and Yuan Zhang and Can Wang and Zhe Wang and Yan Feng and Chun Chen},
90 | title = {Cross-Layer Distillation with Semantic Calibration},
91 | booktitle = {Proceedings of the AAAI Conference on Artificial Intelligence},
92 | pages = {7028--7036},
93 | year = {2021},
94 | }
95 | ```
96 |
97 |
--------------------------------------------------------------------------------
/crd/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__init__.py
--------------------------------------------------------------------------------
/crd/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/crd/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/crd/__pycache__/criterion.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/criterion.cpython-36.pyc
--------------------------------------------------------------------------------
/crd/__pycache__/criterion.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/criterion.cpython-39.pyc
--------------------------------------------------------------------------------
/crd/__pycache__/memory.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/memory.cpython-36.pyc
--------------------------------------------------------------------------------
/crd/__pycache__/memory.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/crd/__pycache__/memory.cpython-39.pyc
--------------------------------------------------------------------------------
/crd/criterion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from .memory import ContrastMemory
4 |
5 | eps = 1e-7
6 |
7 |
8 | class CRDLoss(nn.Module):
9 | """CRD Loss function
10 | includes two symmetric parts:
11 | (a) using the teacher as anchor, choose the positive and negatives from the student side
12 | (b) using the student as anchor, choose the positive and negatives from the teacher side
13 | Args:
14 | opt.s_dim: the dimension of student's feature
15 | opt.t_dim: the dimension of teacher's feature
16 | opt.feat_dim: the dimension of the projection space
17 | opt.nce_k: number of negatives paired with each positive
18 | opt.nce_t: the temperature
19 | opt.nce_m: the momentum for updating the memory buffer
20 | opt.n_data: the number of samples in the training set; therefore, each memory buffer is of size opt.n_data x opt.feat_dim
21 | """
22 | def __init__(self, opt):
23 | super(CRDLoss, self).__init__()
24 | self.embed_s = Embed(opt.s_dim, opt.feat_dim)
25 | self.embed_t = Embed(opt.t_dim, opt.feat_dim)
26 | self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
27 | self.criterion_t = ContrastLoss(opt.n_data)
28 | self.criterion_s = ContrastLoss(opt.n_data)
29 |
30 | def forward(self, f_s, f_t, idx, contrast_idx=None):
31 | """
32 | Args:
33 | f_s: the feature of student network, size [batch_size, s_dim]
34 | f_t: the feature of teacher network, size [batch_size, t_dim]
35 | idx: the indices of these positive samples in the dataset, size [batch_size]
36 | contrast_idx: the indices of negative samples, size [batch_size, nce_k]
37 | Returns:
38 | The contrastive loss
39 | """
40 | f_s = self.embed_s(f_s)
41 | f_t = self.embed_t(f_t)
42 | out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
43 | s_loss = self.criterion_s(out_s)
44 | t_loss = self.criterion_t(out_t)
45 | loss = s_loss + t_loss
46 | return loss
47 |
48 |
49 | class ContrastLoss(nn.Module):
50 | """
51 | contrastive loss, corresponding to Eq (18)
52 | """
53 | def __init__(self, n_data):
54 | super(ContrastLoss, self).__init__()
55 | self.n_data = n_data
56 |
57 | def forward(self, x):
58 | bsz = x.shape[0]
59 | m = x.size(1) - 1
60 |
61 | # noise distribution
62 | Pn = 1 / float(self.n_data)
63 |
64 | # loss for positive pair
65 | P_pos = x.select(1, 0)
66 | log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
67 |
68 | # loss for K negative pair
69 | P_neg = x.narrow(1, 1, m)
70 | log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
71 |
72 | loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
73 |
74 | return loss
75 |
76 |
77 | class Embed(nn.Module):
78 | """Embedding module"""
79 | def __init__(self, dim_in=1024, dim_out=128):
80 | super(Embed, self).__init__()
81 | self.linear = nn.Linear(dim_in, dim_out)
82 | self.l2norm = Normalize(2)
83 |
84 | def forward(self, x):
85 | x = x.view(x.shape[0], -1)
86 | x = self.linear(x)
87 | x = self.l2norm(x)
88 | return x
89 |
90 |
91 | class Normalize(nn.Module):
92 | """normalization layer"""
93 | def __init__(self, power=2):
94 | super(Normalize, self).__init__()
95 | self.power = power
96 |
97 | def forward(self, x):
98 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
99 | out = x.div(norm)
100 | return out
101 |
--------------------------------------------------------------------------------
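A minimal usage sketch of `CRDLoss` with illustrative dimensions (the `opt` fields below are exactly those read in `__init__`; note that `ContrastMemory` moves its alias sampler to the GPU in `__init__`, so a CUDA device is required):

```python
from argparse import Namespace
import torch
from crd.criterion import CRDLoss

opt = Namespace(s_dim=256, t_dim=1024, feat_dim=128,  # feature dims (illustrative)
                nce_k=4096, nce_t=0.07, nce_m=0.5,    # negatives, temperature, momentum
                n_data=50000)                         # e.g., CIFAR-100 training-set size
criterion = CRDLoss(opt).cuda()

f_s = torch.randn(64, opt.s_dim).cuda()           # penultimate student features
f_t = torch.randn(64, opt.t_dim).cuda()           # penultimate teacher features
idx = torch.randint(0, opt.n_data, (64,)).cuda()  # dataset indices of this batch
loss = criterion(f_s, f_t, idx)                   # contrast_idx=None -> negatives sampled internally
```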
/crd/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import math
4 |
5 |
6 | class ContrastMemory(nn.Module):
7 | """
8 | Memory buffer that supplies a large number of negative samples.
9 | """
10 | def __init__(self, inputSize, outputSize, K, T=0.07, momentum=0.5):
11 | super(ContrastMemory, self).__init__()
12 | self.nLem = outputSize
13 | self.unigrams = torch.ones(self.nLem)
14 | self.multinomial = AliasMethod(self.unigrams)
15 | self.multinomial.cuda()
16 | self.K = K
17 |
18 | self.register_buffer('params', torch.tensor([K, T, -1, -1, momentum]))
19 | stdv = 1. / math.sqrt(inputSize / 3)
20 | self.register_buffer('memory_v1', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
21 | self.register_buffer('memory_v2', torch.rand(outputSize, inputSize).mul_(2 * stdv).add_(-stdv))
22 |
23 | def forward(self, v1, v2, y, idx=None):
24 | K = int(self.params[0].item())
25 | T = self.params[1].item()
26 | Z_v1 = self.params[2].item()
27 | Z_v2 = self.params[3].item()
28 |
29 | momentum = self.params[4].item()
30 | batchSize = v1.size(0)
31 | outputSize = self.memory_v1.size(0)
32 | inputSize = self.memory_v1.size(1)
33 |
34 | # original score computation
35 | if idx is None:
36 | idx = self.multinomial.draw(batchSize * (self.K + 1)).view(batchSize, -1)
37 | idx.select(1, 0).copy_(y.data)
38 | # sample
39 | weight_v1 = torch.index_select(self.memory_v1, 0, idx.view(-1)).detach()
40 | weight_v1 = weight_v1.view(batchSize, K + 1, inputSize)
41 | out_v2 = torch.bmm(weight_v1, v2.view(batchSize, inputSize, 1))
42 | out_v2 = torch.exp(torch.div(out_v2, T))
43 | # sample
44 | weight_v2 = torch.index_select(self.memory_v2, 0, idx.view(-1)).detach()
45 | weight_v2 = weight_v2.view(batchSize, K + 1, inputSize)
46 | out_v1 = torch.bmm(weight_v2, v1.view(batchSize, inputSize, 1))
47 | out_v1 = torch.exp(torch.div(out_v1, T))
48 |
49 | # set Z if it has not been set yet
50 | if Z_v1 < 0:
51 | self.params[2] = out_v1.mean() * outputSize
52 | Z_v1 = self.params[2].clone().detach().item()
53 | print("normalization constant Z_v1 is set to {:.1f}".format(Z_v1))
54 | if Z_v2 < 0:
55 | self.params[3] = out_v2.mean() * outputSize
56 | Z_v2 = self.params[3].clone().detach().item()
57 | print("normalization constant Z_v2 is set to {:.1f}".format(Z_v2))
58 |
59 | # compute out_v1, out_v2
60 | out_v1 = torch.div(out_v1, Z_v1).contiguous()
61 | out_v2 = torch.div(out_v2, Z_v2).contiguous()
62 |
63 | # update memory
64 | with torch.no_grad():
65 | l_pos = torch.index_select(self.memory_v1, 0, y.view(-1))
66 | l_pos.mul_(momentum)
67 | l_pos.add_(torch.mul(v1, 1 - momentum))
68 | l_norm = l_pos.pow(2).sum(1, keepdim=True).pow(0.5)
69 | updated_v1 = l_pos.div(l_norm)
70 | self.memory_v1.index_copy_(0, y, updated_v1)
71 |
72 | ab_pos = torch.index_select(self.memory_v2, 0, y.view(-1))
73 | ab_pos.mul_(momentum)
74 | ab_pos.add_(torch.mul(v2, 1 - momentum))
75 | ab_norm = ab_pos.pow(2).sum(1, keepdim=True).pow(0.5)
76 | updated_v2 = ab_pos.div(ab_norm)
77 | self.memory_v2.index_copy_(0, y, updated_v2)
78 |
79 | return out_v1, out_v2
80 |
81 |
82 | class AliasMethod(object):
83 | """
84 | From: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
85 | """
86 | def __init__(self, probs):
87 |
88 | if probs.sum() > 1:
89 | probs.div_(probs.sum())
90 | K = len(probs)
91 | self.prob = torch.zeros(K)
92 | self.alias = torch.LongTensor([0]*K)
93 |
94 | # Sort the data into the outcomes with probabilities
95 | # that are larger and smaller than 1/K.
96 | smaller = []
97 | larger = []
98 | for kk, prob in enumerate(probs):
99 | self.prob[kk] = K*prob
100 | if self.prob[kk] < 1.0:
101 | smaller.append(kk)
102 | else:
103 | larger.append(kk)
104 |
105 | # Loop through and create little binary mixtures that
106 | # appropriately allocate the larger outcomes over the
107 | # overall uniform mixture.
108 | while len(smaller) > 0 and len(larger) > 0:
109 | small = smaller.pop()
110 | large = larger.pop()
111 |
112 | self.alias[small] = large
113 | self.prob[large] = (self.prob[large] - 1.0) + self.prob[small]
114 |
115 | if self.prob[large] < 1.0:
116 | smaller.append(large)
117 | else:
118 | larger.append(large)
119 |
120 | for last_one in smaller+larger:
121 | self.prob[last_one] = 1
122 |
123 | def cuda(self):
124 | self.prob = self.prob.cuda()
125 | self.alias = self.alias.cuda()
126 |
127 | def draw(self, N):
128 | """ Draw N samples from multinomial """
129 | K = self.alias.size(0)
130 |
131 | kk = torch.zeros(N, dtype=torch.long, device=self.prob.device).random_(0, K)
132 | prob = self.prob.index_select(0, kk)
133 | alias = self.alias.index_select(0, kk)
134 | # b == 1: keep the sampled bin kk; b == 0: take its alias instead
135 | b = torch.bernoulli(prob)
136 | oq = kk.mul(b.long())
137 | oj = alias.mul((1-b).long())
138 |
139 | return oq + oj
--------------------------------------------------------------------------------
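A quick sanity check of `AliasMethod` (CPU-only here, since `draw` follows the device of its tables; values illustrative): the drawn frequencies should approach the input distribution.

```python
import torch
from crd.memory import AliasMethod

probs = torch.tensor([0.1, 0.2, 0.7])
sampler = AliasMethod(probs)       # builds the prob/alias tables in O(K)
samples = sampler.draw(100_000)    # O(1) per draw
print(torch.bincount(samples, minlength=3) / 100_000)  # ~tensor([0.10, 0.20, 0.70])
```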
/dataset/__pycache__/cifar100.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/cifar100.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/cifar100.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/cifar100.cpython-39.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/imagenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet.cpython-39.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/imagenet_dali.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet_dali.cpython-36.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/imagenet_dali.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/dataset/__pycache__/imagenet_dali.cpython-39.pyc
--------------------------------------------------------------------------------
/dataset/base.py:
--------------------------------------------------------------------------------
1 | # https://github.com/tanglang96/DataLoaders_DALI/blob/master/base.py
2 |
3 | from nvidia.dali.plugin.pytorch import DALIGenericIterator
4 |
5 | class DALIDataloader(DALIGenericIterator):
6 | def __init__(self, pipeline, size, batch_size, output_map=["data", "label"], auto_reset=True, onehot_label=False):
7 | self.size = size
8 | self.batch_size = batch_size
9 | self.onehot_label = onehot_label
10 | self.output_map = output_map
11 | super().__init__(pipelines=pipeline, size=size, auto_reset=auto_reset, output_map=output_map)
12 |
13 | def __next__(self):
14 | if self._first_batch is not None:
15 | batch = self._first_batch
16 | self._first_batch = None
17 | return batch
18 | data = super().__next__()[0]
19 | if self.onehot_label:
20 | return [data[self.output_map[0]], data[self.output_map[1]].squeeze().long()]
21 | else:
22 | return [data[self.output_map[0]], data[self.output_map[1]]]
23 |
24 | def __len__(self):
25 | if self.size % self.batch_size == 0:
26 | return self.size // self.batch_size
27 | else:
28 | return self.size // self.batch_size + 1
29 |
--------------------------------------------------------------------------------
/dataset/cifar100.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import numpy as np
5 | from torch.utils.data import DataLoader
6 | from torchvision import datasets, transforms
7 | from PIL import Image
8 |
9 | """
10 | mean = {
11 | 'cifar100': (0.5071, 0.4867, 0.4408),
12 | }
13 |
14 | std = {
15 | 'cifar100': (0.2675, 0.2565, 0.2761),
16 | }
17 | """
18 |
19 |
20 | def get_data_folder():
21 | """
22 | return the path to store the data
23 | """
24 | data_folder = '../data/'
25 |
26 | if not os.path.isdir(data_folder):
27 | os.makedirs(data_folder)
28 |
29 | return data_folder
30 |
31 | class CIFAR100BackCompat(datasets.CIFAR100):
32 | """
33 | Backward-compatibility accessors (train_labels, test_labels, train_data, test_data) for torchvision's CIFAR100
34 | """
35 |
36 | @property
37 | def train_labels(self):
38 | return self.targets
39 |
40 | @property
41 | def test_labels(self):
42 | return self.targets
43 |
44 | @property
45 | def train_data(self):
46 | return self.data
47 |
48 | @property
49 | def test_data(self):
50 | return self.data
51 |
52 | class CIFAR100Instance(CIFAR100BackCompat):
53 | """CIFAR100Instance Dataset.
54 | """
55 | def __getitem__(self, index):
56 |
57 | img, target = self.data[index], self.targets[index]
58 |
59 | # doing this so that it is consistent with all other datasets
60 | # to return a PIL Image
61 | img = Image.fromarray(img)
62 |
63 | if self.transform is not None:
64 | img = self.transform(img)
65 |
66 | if self.target_transform is not None:
67 | target = self.target_transform(target)
68 |
69 | return img, target, index
70 |
71 |
72 | def get_cifar100_dataloaders(batch_size=128, num_workers=8, is_instance=False):
73 | """
74 | CIFAR-100 dataloaders
75 | """
76 | data_folder = get_data_folder()
77 |
78 | train_transform = transforms.Compose([
79 | transforms.RandomCrop(32, padding=4),
80 | transforms.RandomHorizontalFlip(),
81 | transforms.ToTensor(),
82 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
83 | ])
84 | test_transform = transforms.Compose([
85 | transforms.ToTensor(),
86 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
87 | ])
88 |
89 | if is_instance:
90 | train_set = CIFAR100Instance(root=data_folder,
91 | download=True,
92 | train=True,
93 | transform=train_transform)
94 | n_data = len(train_set)
95 | else:
96 | train_set = datasets.CIFAR100(root=data_folder,
97 | download=True,
98 | train=True,
99 | transform=train_transform)
100 | train_loader = DataLoader(train_set,
101 | batch_size=batch_size,
102 | shuffle=True,
103 | num_workers=num_workers)
104 |
105 | test_set = datasets.CIFAR100(root=data_folder,
106 | download=True,
107 | train=False,
108 | transform=test_transform)
109 | test_loader = DataLoader(test_set,
110 | batch_size=int(batch_size/2),
111 | shuffle=False,
112 | num_workers=int(num_workers/2))
113 |
114 | if is_instance:
115 | return train_loader, test_loader, n_data
116 | else:
117 | return train_loader, test_loader
118 |
119 |
120 | class CIFAR100InstanceSample(CIFAR100BackCompat):
121 | """
122 | CIFAR100Instance+Sample Dataset
123 | """
124 | def __init__(self, root, train=True,
125 | transform=None, target_transform=None,
126 | download=False, k=4096, mode='exact', is_sample=True, percent=1.0):
127 | super().__init__(root=root, train=train, download=download,
128 | transform=transform, target_transform=target_transform)
129 | self.k = k
130 | self.mode = mode
131 | self.is_sample = is_sample
132 |
133 | num_classes = 100
134 | num_samples = len(self.data)
135 | label = self.targets
136 |
137 | self.cls_positive = [[] for i in range(num_classes)]
138 | for i in range(num_samples):
139 | self.cls_positive[label[i]].append(i)
140 |
141 | self.cls_negative = [[] for i in range(num_classes)]
142 | for i in range(num_classes):
143 | for j in range(num_classes):
144 | if j == i:
145 | continue
146 | self.cls_negative[i].extend(self.cls_positive[j])
147 |
148 | self.cls_positive = [np.asarray(self.cls_positive[i]) for i in range(num_classes)]
149 | self.cls_negative = [np.asarray(self.cls_negative[i]) for i in range(num_classes)]
150 |
151 | if 0 < percent < 1:
152 | n = int(len(self.cls_negative[0]) * percent)
153 | self.cls_negative = [np.random.permutation(self.cls_negative[i])[0:n]
154 | for i in range(num_classes)]
155 |
156 | self.cls_positive = np.asarray(self.cls_positive)
157 | self.cls_negative = np.asarray(self.cls_negative)
158 |
159 | def __getitem__(self, index):
160 |
161 | img, target = self.data[index], self.targets[index]
162 |
163 | # doing this so that it is consistent with all other datasets
164 | # to return a PIL Image
165 | img = Image.fromarray(img)
166 |
167 | if self.transform is not None:
168 | img = self.transform(img)
169 |
170 | if self.target_transform is not None:
171 | target = self.target_transform(target)
172 |
173 | if not self.is_sample:
174 | # directly return
175 | return img, target, index
176 | else:
177 | # sample contrastive examples
178 | if self.mode == 'exact':
179 | pos_idx = index
180 | elif self.mode == 'relax':
181 | pos_idx = np.random.choice(self.cls_positive[target], 1)
182 | pos_idx = pos_idx[0]
183 | else:
184 | raise NotImplementedError(self.mode)
185 | replace = True if self.k > len(self.cls_negative[target]) else False
186 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=replace)
187 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
188 | return img, target, index, sample_idx
189 |
190 | def get_cifar100_dataloaders_sample(batch_size=128, num_workers=8, k=4096, mode='exact',
191 | is_sample=True, percent=1.0):
192 | """
193 | CIFAR-100 dataloaders with contrastive sampling (for CRD)
194 | """
195 | data_folder = get_data_folder()
196 |
197 | train_transform = transforms.Compose([
198 | transforms.RandomCrop(32, padding=4),
199 | transforms.RandomHorizontalFlip(),
200 | transforms.ToTensor(),
201 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
202 | ])
203 | test_transform = transforms.Compose([
204 | transforms.ToTensor(),
205 | transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
206 | ])
207 |
208 | train_set = CIFAR100InstanceSample(root=data_folder,
209 | download=True,
210 | train=True,
211 | transform=train_transform,
212 | k=k,
213 | mode=mode,
214 | is_sample=is_sample,
215 | percent=percent)
216 | n_data = len(train_set)
217 | train_loader = DataLoader(train_set,
218 | batch_size=batch_size,
219 | shuffle=True,
220 | num_workers=num_workers)
221 |
222 | test_set = datasets.CIFAR100(root=data_folder,
223 | download=True,
224 | train=False,
225 | transform=test_transform)
226 | test_loader = DataLoader(test_set,
227 | batch_size=int(batch_size/2),
228 | shuffle=False,
229 | num_workers=int(num_workers/2))
230 |
231 | return train_loader, test_loader, n_data
232 |
--------------------------------------------------------------------------------
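A usage sketch of the sampling loader (it downloads CIFAR-100 into `../data/` on first use; batch size and `k` are illustrative):

```python
from dataset.cifar100 import get_cifar100_dataloaders_sample

train_loader, test_loader, n_data = get_cifar100_dataloaders_sample(
    batch_size=64, num_workers=8, k=4096, mode='exact')
images, labels, index, contrast_idx = next(iter(train_loader))
# images: [64, 3, 32, 32]; contrast_idx: [64, 4097] = 1 positive + k negatives
```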
/dataset/imagenet.py:
--------------------------------------------------------------------------------
1 | """
2 | get data loaders
3 | """
4 | from __future__ import print_function
5 |
6 | import os
7 | import numpy as np
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torchvision import datasets
11 | from torchvision import transforms
12 |
13 | def get_data_folder(dataset='imagenet'):
14 | """
15 | return the path to store the data
16 | """
17 | data_folder = os.path.join('/home/cdf/RepDistiller-master/RepDistiller-master/data', dataset)
18 |
19 | if not os.path.isdir(data_folder):
20 | os.makedirs(data_folder)
21 |
22 | return data_folder
23 |
24 |
25 | class ImageFolderInstance(datasets.ImageFolder):
26 | """: Folder datasets which returns the index of the image as well::
27 | """
28 | def __getitem__(self, index):
29 | """
30 | Args:
31 | index (int): Index
32 | Returns:
33 | tuple: (image, target, index) where target is class_index of the target class.
34 | """
35 | path, target = self.imgs[index]
36 | img = self.loader(path)
37 | if self.transform is not None:
38 | img = self.transform(img)
39 | if self.target_transform is not None:
40 | target = self.target_transform(target)
41 |
42 | return img, target, index
43 |
44 |
45 | class ImageFolderSample(datasets.ImageFolder):
46 | """: Folder datasets which returns (img, label, index, contrast_index):
47 | """
48 | def __init__(self, root, transform=None, target_transform=None,
49 | is_sample=False, k=4096):
50 | super().__init__(root=root, transform=transform, target_transform=target_transform)
51 |
52 | self.k = k
53 | self.is_sample = is_sample
54 |
55 | print('stage1 finished!')
56 |
57 | if self.is_sample:
58 | num_classes = len(self.classes)
59 | num_samples = len(self.samples)
60 | label = np.zeros(num_samples, dtype=np.int32)
61 | for i in range(num_samples):
62 | path, target = self.imgs[i]
63 | label[i] = target
64 |
65 | self.cls_positive = [[] for i in range(num_classes)]
66 | for i in range(num_samples):
67 | self.cls_positive[label[i]].append(i)
68 |
69 | self.cls_negative = [[] for i in range(num_classes)]
70 | for i in range(num_classes):
71 | for j in range(num_classes):
72 | if j == i:
73 | continue
74 | self.cls_negative[i].extend(self.cls_positive[j])
75 |
76 | self.cls_positive = [np.asarray(self.cls_positive[i], dtype=np.int32) for i in range(num_classes)]
77 | self.cls_negative = [np.asarray(self.cls_negative[i], dtype=np.int32) for i in range(num_classes)]
78 |
79 | print('dataset initialized!')
80 |
81 | def __getitem__(self, index):
82 | """
83 | Args:
84 | index (int): Index
85 | Returns:
86 | tuple: (image, target, index) or, when is_sample is True, (image, target, index, sample_idx).
87 | """
88 | path, target = self.imgs[index]
89 | img = self.loader(path)
90 | if self.transform is not None:
91 | img = self.transform(img)
92 | if self.target_transform is not None:
93 | target = self.target_transform(target)
94 |
95 | if self.is_sample:
96 | # sample contrastive examples
97 | pos_idx = index
98 | neg_idx = np.random.choice(self.cls_negative[target], self.k, replace=True)
99 | sample_idx = np.hstack((np.asarray([pos_idx]), neg_idx))
100 | # sample_idx = np.hstack((neg_idx))
101 | return img, target, index, sample_idx
102 | else:
103 | return img, target, index
104 |
105 |
106 | def get_test_loader(dataset='imagenet', batch_size=128, num_workers=8):
107 | """get the test data loader"""
108 |
109 | if dataset == 'imagenet':
110 | data_folder = get_data_folder(dataset)
111 | else:
112 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
113 |
114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
115 | std=[0.229, 0.224, 0.225])
116 | test_transform = transforms.Compose([
117 | transforms.Resize(256),
118 | transforms.CenterCrop(224),
119 | transforms.ToTensor(),
120 | normalize,
121 | ])
122 |
123 | test_folder = os.path.join(data_folder, 'val')
124 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
125 | test_loader = DataLoader(test_set,
126 | batch_size=batch_size,
127 | shuffle=False,
128 | num_workers=num_workers,
129 | pin_memory=True)
130 |
131 | return test_loader
132 |
133 |
134 | def get_dataloader_sample(dataset='imagenet', batch_size=128, num_workers=8,
135 | is_sample=False, k=4096, multiprocessing_distributed=False):
136 | """Data Loader for ImageNet"""
137 |
138 | if dataset == 'imagenet':
139 | data_folder = get_data_folder(dataset)
140 | else:
141 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
142 |
143 | # add data transform
144 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
145 | std=[0.229, 0.224, 0.225])
146 | train_transform = transforms.Compose([
147 | transforms.RandomResizedCrop(224),
148 | transforms.RandomHorizontalFlip(),
149 | transforms.ToTensor(),
150 | normalize,
151 | ])
152 | test_transform = transforms.Compose([
153 | transforms.Resize(256),
154 | transforms.CenterCrop(224),
155 | transforms.ToTensor(),
156 | normalize,
157 | ])
158 | train_folder = os.path.join(data_folder, 'train')
159 | test_folder = os.path.join(data_folder, 'val')
160 |
161 | train_set = ImageFolderSample(train_folder, transform=train_transform, is_sample=is_sample, k=k)
162 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
163 |
164 | if multiprocessing_distributed:
165 | train_sampler = DistributedSampler(train_set)
166 | test_sampler = DistributedSampler(test_set, shuffle=False)
167 | else:
168 | train_sampler = None
169 | test_sampler = None
170 |
171 | train_loader = DataLoader(train_set,
172 | batch_size=batch_size,
173 | shuffle=(train_sampler is None),
174 | num_workers=num_workers,
175 | pin_memory=True,
176 | sampler=train_sampler)
177 | test_loader = DataLoader(test_set,
178 | batch_size=batch_size,
179 | shuffle=False,
180 | num_workers=num_workers,
181 | pin_memory=True,
182 | sampler=test_sampler)
183 |
184 | print('num_samples', len(train_set.samples))
185 | print('num_class', len(train_set.classes))
186 |
187 | return train_loader, test_loader, len(train_set), len(train_set.classes), train_sampler
188 |
189 |
190 | def get_imagenet_dataloader(dataset='imagenet', batch_size=128, num_workers=16,
191 | multiprocessing_distributed=False):
192 | """
193 | Data Loader for imagenet
194 | """
195 | if dataset == 'imagenet':
196 | data_folder = get_data_folder(dataset)
197 | else:
198 | raise NotImplementedError('dataset not supported: {}'.format(dataset))
199 |
200 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
201 | std=[0.229, 0.224, 0.225])
202 | train_transform = transforms.Compose([
203 | transforms.RandomResizedCrop(224),
204 | transforms.RandomHorizontalFlip(),
205 | transforms.ToTensor(),
206 | normalize,
207 | ])
208 | test_transform = transforms.Compose([
209 | transforms.Resize(256),
210 | transforms.CenterCrop(224),
211 | transforms.ToTensor(),
212 | normalize,
213 | ])
214 |
215 | train_folder = os.path.join(data_folder, 'train')
216 | test_folder = os.path.join(data_folder, 'val')
217 |
218 | train_set = datasets.ImageFolder(train_folder, transform=train_transform)
219 | test_set = datasets.ImageFolder(test_folder, transform=test_transform)
220 |
221 | if multiprocessing_distributed:
222 | train_sampler = DistributedSampler(train_set)
223 | test_sampler = DistributedSampler(test_set, shuffle=False)
224 | else:
225 | train_sampler = None
226 | test_sampler = None
227 |
228 | train_loader = DataLoader(train_set,
229 | batch_size=batch_size,
230 | shuffle=(train_sampler is None),
231 | num_workers=num_workers,
232 | pin_memory=True,
233 | sampler=train_sampler)
234 |
235 | test_loader = DataLoader(test_set,
236 | batch_size=batch_size,
237 | shuffle=False,
238 | num_workers=num_workers,
239 | pin_memory=True,
240 | sampler=test_sampler)
241 |
242 | return train_loader, test_loader, train_sampler
243 |
--------------------------------------------------------------------------------
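Hypothetical calls (both expect ImageNet under the hard-coded `data_folder` above; argument values are examples):

```python
from dataset.imagenet import get_imagenet_dataloader, get_dataloader_sample

# plain loaders; train_sampler is None unless multiprocessing_distributed=True
train_loader, test_loader, train_sampler = get_imagenet_dataloader(batch_size=128)

# CRD-style loaders, additionally returning dataset size and class count
train_loader, test_loader, n_data, n_cls, train_sampler = get_dataloader_sample(
    is_sample=True, k=4096)
```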
/dataset/imagenet_dali.py:
--------------------------------------------------------------------------------
1 | # https://github.com/NVIDIA/DALI/blob/master/docs/examples/use_cases/pytorch/resnet50/main.py
2 |
3 | import argparse
4 | import os
5 | import shutil
6 | import time
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.parallel
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.optim
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 | import torchvision.models as models
20 |
21 | import numpy as np
22 |
23 | try:
24 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
25 | from nvidia.dali.pipeline import pipeline_def
26 | import nvidia.dali.types as types
27 | import nvidia.dali.fn as fn
28 | except ImportError:
29 | raise ImportError("Please install DALI from https://www.github.com/NVIDIA/DALI to run this example.")
30 |
31 | @pipeline_def
32 | def create_dali_pipeline(data_dir, crop, size, shard_id, num_shards, dali_cpu=False, is_training=True):
33 | images, labels = fn.readers.file(file_root=data_dir,
34 | shard_id=shard_id,
35 | num_shards=num_shards,
36 | random_shuffle=is_training,
37 | pad_last_batch=True,
38 | name="Reader")
39 | dali_device = 'cpu' if dali_cpu else 'gpu'
40 | decoder_device = 'cpu' if dali_cpu else 'mixed'
41 | device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
42 | host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
43 | if is_training:
44 | images = fn.decoders.image_random_crop(images,
45 | device=decoder_device, output_type=types.RGB,
46 | device_memory_padding=device_memory_padding,
47 | host_memory_padding=host_memory_padding,
48 | random_aspect_ratio=[0.8, 1.25],
49 | random_area=[0.1, 1.0],
50 | num_attempts=100)
51 | images = fn.resize(images,
52 | device=dali_device,
53 | resize_x=crop,
54 | resize_y=crop,
55 | interp_type=types.INTERP_TRIANGULAR)
56 | mirror = fn.random.coin_flip(probability=0.5)
57 | else:
58 | images = fn.decoders.image(images,
59 | device=decoder_device,
60 | output_type=types.RGB)
61 | images = fn.resize(images,
62 | device=dali_device,
63 | size=size,
64 | mode="not_smaller",
65 | interp_type=types.INTERP_TRIANGULAR)
66 | mirror = False
67 |
68 | images = fn.crop_mirror_normalize(images.gpu(),
69 | dtype=types.FLOAT,
70 | output_layout="CHW",
71 | crop=(crop, crop),
72 | mean=[0.485 * 255,0.456 * 255,0.406 * 255],
73 | std=[0.229 * 255,0.224 * 255,0.225 * 255],
74 | mirror=mirror)
75 | labels = labels.gpu()
76 | return images, labels
77 |
78 | def get_dali_data_loader(args):
79 | crop_size = 224
80 | val_size = 256
81 |
82 | path = '../data'
83 | data_folder = os.path.join(path, args.dataset)
84 | if not os.path.isdir(data_folder):
85 | print('Please place the ImageNet dataset at: ', path)
86 |
87 | traindir = os.path.join(data_folder, 'train')
88 | valdir = os.path.join(data_folder, 'val')
89 |
90 | pipe = create_dali_pipeline(batch_size=args.batch_size,
91 | num_threads=args.num_workers,
92 | device_id=args.rank,
93 | seed=12 + args.rank,
94 | data_dir=traindir,
95 | crop=crop_size,
96 | size=val_size,
97 | dali_cpu=args.dali == 'cpu',
98 | shard_id=args.rank,
99 | num_shards=args.world_size,
100 | is_training=True)
101 | pipe.build()
102 | train_loader = DALIClassificationIterator(pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)
103 |
104 | pipe = create_dali_pipeline(batch_size=args.batch_size,
105 | num_threads=args.num_workers,
106 | device_id=args.rank,
107 | seed=12 + args.rank,
108 | data_dir=valdir,
109 | crop=crop_size,
110 | size=val_size,
111 | dali_cpu=args.dali == 'cpu',
112 | shard_id=args.rank,
113 | num_shards=args.world_size,
114 | is_training=False)
115 | pipe.build()
116 | val_loader = DALIClassificationIterator(pipe, reader_name="Reader", last_batch_policy=LastBatchPolicy.PARTIAL)
117 |
118 | return train_loader, val_loader
--------------------------------------------------------------------------------
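A hypothetical single-process invocation (the `args` fields are exactly the attributes read above; values are examples). The iterators yield batches in the `data[0]['data']` / `data[0]['label']` layout consumed by `helper/loops.py`:

```python
from argparse import Namespace
from dataset.imagenet_dali import get_dali_data_loader

args = Namespace(batch_size=256, num_workers=16, rank=0, world_size=1,
                 dataset='imagenet', dali='gpu')
train_loader, val_loader = get_dali_data_loader(args)

for data in train_loader:
    images = data[0]['data']                    # already on GPU and normalized
    labels = data[0]['label'].squeeze().long()  # same unpacking as helper/loops.py
    break
```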
/distiller_zoo/AT.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class Attention(nn.Module):
8 | """Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks
9 | via Attention Transfer
10 | code: https://github.com/szagoruyko/attention-transfer"""
11 | def __init__(self, p=2):
12 | super(Attention, self).__init__()
13 | self.p = p
14 |
15 | def forward(self, g_s, g_t):
16 | # zip truncates to the shorter list, so only min(len(g_s), len(g_t)) at_loss terms are computed
17 | return [self.at_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
18 |
19 | def at_loss(self, f_s, f_t):
20 | s_H, t_H = f_s.shape[2], f_t.shape[2]
21 | if s_H > t_H:
22 | f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))
23 | elif s_H < t_H:
24 | f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))
25 | else:
26 | pass
27 | return (self.at(f_s) - self.at(f_t)).pow(2).mean()
28 |
29 | def at(self, f):
30 | # mean(1) reduces the BxCxHxW feature map to BxHxW by averaging over channels
31 | return F.normalize(f.pow(self.p).mean(1).view(f.size(0), -1))
32 |
--------------------------------------------------------------------------------
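A shape walk-through under assumed feature sizes (illustrative). Channel counts may differ between student and teacher because `at` averages over channels before normalizing:

```python
import torch
from distiller_zoo import Attention

at = Attention(p=2)
f_s = torch.randn(8, 64, 16, 16)  # student map, larger spatially
f_t = torch.randn(8, 256, 8, 8)   # teacher map
loss = sum(at([f_s], [f_t]))      # f_s is average-pooled to 8x8 before comparison
```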
/distiller_zoo/FitNet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 |
5 |
6 | class HintLoss(nn.Module):
7 | """Fitnets: hints for thin deep nets, ICLR 2015"""
8 | def __init__(self):
9 | super(HintLoss, self).__init__()
10 | self.crit = nn.MSELoss()
11 |
12 | def forward(self, f_s, f_t):
13 | loss = self.crit(f_s, f_t)
14 | return loss
15 |
--------------------------------------------------------------------------------
/distiller_zoo/KD.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class DistillKL(nn.Module):
8 | """Distilling the Knowledge in a Neural Network"""
9 | def __init__(self, T):
10 | super(DistillKL, self).__init__()
11 | self.T = T
12 |
13 | def forward(self, y_s, y_t):
14 | p_s = F.log_softmax(y_s/self.T, dim=1)
15 | p_t = F.softmax(y_t/self.T, dim=1)
16 | loss = nn.KLDivLoss(reduction='batchmean')(p_s, p_t) * (self.T**2)
17 | return loss
18 |
--------------------------------------------------------------------------------
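A minimal sketch with an illustrative temperature:

```python
import torch
from distiller_zoo import DistillKL

criterion_div = DistillKL(T=4)                       # T=4 is illustrative
y_s, y_t = torch.randn(8, 100), torch.randn(8, 100)  # student / teacher logits
loss = criterion_div(y_s, y_t)                       # KL(p_t || p_s) scaled by T^2
```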
/distiller_zoo/SP.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class Similarity(nn.Module):
9 | """Similarity-Preserving Knowledge Distillation, ICCV2019, verified by original author"""
10 | def __init__(self):
11 | super(Similarity, self).__init__()
12 |
13 | def forward(self, g_s, g_t):
14 | return [self.similarity_loss(f_s, f_t) for f_s, f_t in zip(g_s, g_t)]
15 |
16 | def similarity_loss(self, f_s, f_t):
17 | bsz = f_s.shape[0]
18 | f_s = f_s.view(bsz, -1)
19 | f_t = f_t.view(bsz, -1)
20 |
21 | G_s = torch.mm(f_s, torch.t(f_s))
22 | # G_s = G_s / G_s.norm(2)
23 | G_s = torch.nn.functional.normalize(G_s, dim=1)
24 | G_t = torch.mm(f_t, torch.t(f_t))
25 | # G_t = G_t / G_t.norm(2)
26 | G_t = torch.nn.functional.normalize(G_t, dim=1)
27 |
28 | G_diff = G_t - G_s
29 | loss = (G_diff * G_diff).view(-1, 1).sum(0) / (bsz * bsz)
30 | return loss
31 |
--------------------------------------------------------------------------------
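A minimal sketch; only the batch dimension must match between the paired features, since each is flattened before its bsz x bsz Gram matrix is formed:

```python
import torch
from distiller_zoo import Similarity

sp = Similarity()
f_s = torch.randn(8, 64, 16, 16)  # student feature map
f_t = torch.randn(8, 256, 8, 8)   # teacher feature map (different shape is fine)
loss = sum(sp([f_s], [f_t]))      # G_s and G_t are both 8 x 8
```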
/distiller_zoo/SemCKD.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class SemCKDLoss(nn.Module):
9 | """Cross-Layer Distillation with Semantic Calibration, AAAI2021"""
10 | def __init__(self):
11 | super(SemCKDLoss, self).__init__()
12 | self.crit = nn.MSELoss(reduction='none')
13 |
14 | def forward(self, s_value, f_target, weight):
15 | bsz, num_stu, num_tea = weight.shape
16 | ind_loss = torch.zeros(bsz, num_stu, num_tea).cuda()
17 |
18 | for i in range(num_stu):
19 | for j in range(num_tea):
20 | ind_loss[:, i, j] = self.crit(s_value[i][j], f_target[i][j]).reshape(bsz,-1).mean(-1)
21 |
22 | loss = (weight * ind_loss).sum()/(1.0*bsz*num_stu)
23 | return loss
--------------------------------------------------------------------------------
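A sketch of the expected inputs (shapes illustrative; `forward` allocates on CUDA, so a GPU is required). Here `s_value[i][j]` holds the i-th student feature projected to match the j-th teacher layer, and `weight` carries the per-sample attention over layer pairs:

```python
import torch
from distiller_zoo import SemCKDLoss

bsz, num_stu, num_tea = 8, 3, 3                 # illustrative sizes
s_value = [[torch.randn(bsz, 64, 16, 16).cuda() for _ in range(num_tea)]
           for _ in range(num_stu)]             # projected student features
f_target = [[torch.randn(bsz, 64, 16, 16).cuda() for _ in range(num_tea)]
            for _ in range(num_stu)]            # matched teacher targets
weight = torch.softmax(torch.randn(bsz, num_stu, num_tea), dim=-1).cuda()
loss = SemCKDLoss()(s_value, f_target, weight)
```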
/distiller_zoo/VID.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | class VIDLoss(nn.Module):
10 | """Variational Information Distillation for Knowledge Transfer (CVPR 2019),
11 | code from author: https://github.com/ssahn0215/variational-information-distillation"""
12 | def __init__(self,
13 | num_input_channels,
14 | num_mid_channel,
15 | num_target_channels,
16 | init_pred_var=5.0,
17 | eps=1e-5):
18 | super(VIDLoss, self).__init__()
19 |
20 | def conv1x1(in_channels, out_channels, stride=1):
21 | return nn.Conv2d(
22 | in_channels, out_channels,
23 | kernel_size=1, padding=0,
24 | bias=False, stride=stride)
25 |
26 | self.regressor = nn.Sequential(
27 | conv1x1(num_input_channels, num_mid_channel),
28 | nn.ReLU(),
29 | conv1x1(num_mid_channel, num_mid_channel),
30 | nn.ReLU(),
31 | conv1x1(num_mid_channel, num_target_channels),
32 | )
33 | self.log_scale = torch.nn.Parameter(
34 | np.log(np.exp(init_pred_var-eps)-1.0) * torch.ones(num_target_channels)
35 | )
36 | self.eps = eps
37 |
38 | def forward(self, input, target):
39 | # pool for dimension match
40 | s_H, t_H = input.shape[2], target.shape[2]
41 | if s_H > t_H:
42 | input = F.adaptive_avg_pool2d(input, (t_H, t_H))
43 | elif s_H < t_H:
44 | target = F.adaptive_avg_pool2d(target, (s_H, s_H))
45 | else:
46 | pass
47 | pred_mean = self.regressor(input)
48 | pred_var = torch.log(1.0+torch.exp(self.log_scale))+self.eps
49 | pred_var = pred_var.view(1, -1, 1, 1)
50 | neg_log_prob = 0.5*(
51 | (pred_mean-target)**2/pred_var+torch.log(pred_var)
52 | )
53 | loss = torch.mean(neg_log_prob)
54 | return loss
55 |
--------------------------------------------------------------------------------
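One `VIDLoss` is built per matched layer pair; a minimal sketch with illustrative channel counts:

```python
import torch
from distiller_zoo import VIDLoss

vid = VIDLoss(num_input_channels=64, num_mid_channel=128,
              num_target_channels=128)  # channel counts are illustrative
f_s = torch.randn(8, 64, 16, 16)        # student feature map
f_t = torch.randn(8, 128, 16, 16)       # teacher feature map
loss = vid(f_s, f_t)                    # Gaussian NLL with learned per-channel variance
```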
/distiller_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .FitNet import HintLoss
2 | from .AT import Attention
3 | from .KD import DistillKL
4 | from .SP import Similarity
5 | from .VID import VIDLoss
6 | from .SemCKD import SemCKDLoss
7 |
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/AT.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/AT.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/AT.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/AT.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/FitNet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/FitNet.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/FitNet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/FitNet.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/KD.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/KD.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/KD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/KD.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/SP.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SP.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/SP.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SP.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/SemCKD.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SemCKD.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/SemCKD.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/SemCKD.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/VID.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/VID.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/VID.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/VID.cpython-39.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/distiller_zoo/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/distiller_zoo/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/helper/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__init__.py
--------------------------------------------------------------------------------
/helper/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/helper/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/helper/__pycache__/loops.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/loops.cpython-36.pyc
--------------------------------------------------------------------------------
/helper/__pycache__/loops.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/loops.cpython-39.pyc
--------------------------------------------------------------------------------
/helper/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/helper/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/helper/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/helper/loops.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 |
3 |
4 | import sys
5 | import time
6 | import torch
7 | from .util import AverageMeter, accuracy, reduce_tensor
8 |
9 | def train_vanilla(epoch, train_loader, model, criterion, optimizer, opt):
10 | """vanilla training"""
11 | model.train()
12 |
13 | batch_time = AverageMeter()
14 | losses = AverageMeter()
15 | top1 = AverageMeter()
16 | top5 = AverageMeter()
17 |
18 | n_batch = len(train_loader) if opt.dali is None else (train_loader._size + opt.batch_size - 1) // opt.batch_size
19 |
20 | end = time.time()
21 | for idx, batch_data in enumerate(train_loader):
22 | if opt.dali is None:
23 | images, labels = batch_data
24 | else:
25 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long()
26 |
27 | if opt.gpu is not None:
28 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
29 | if torch.cuda.is_available():
30 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
31 |
32 | # ===================forward=====================
33 | output = model(images)
34 | loss = criterion(output, labels)
35 | losses.update(loss.item(), images.size(0))
36 |
37 | # ===================Metrics=====================
38 | metrics = accuracy(output, labels, topk=(1, 5))
39 | top1.update(metrics[0].item(), images.size(0))
40 | top5.update(metrics[1].item(), images.size(0))
41 |         batch_time.update(time.time() - end); end = time.time()  # reset so each batch is timed separately
42 |
43 | # ===================backward=====================
44 | optimizer.zero_grad()
45 | loss.backward()
46 |
47 | optimizer.step()
48 |
49 | # print info
50 | if idx % opt.print_freq == 0:
51 | print('Epoch: [{0}][{1}/{2}]\t'
52 | 'GPU {3}\t'
53 | 'Time: {batch_time.avg:.3f}\t'
54 | 'Loss {loss.avg:.4f}\t'
55 | 'Acc@1 {top1.avg:.3f}\t'
56 | 'Acc@5 {top5.avg:.3f}'.format(
57 | epoch, idx, n_batch, opt.gpu, batch_time=batch_time,
58 | loss=losses, top1=top1, top5=top5))
59 | sys.stdout.flush()
60 |
61 | return top1.avg, top5.avg, losses.avg
62 |
63 | def train_distill(epoch, train_loader, module_list, criterion_list, optimizer, opt):
64 | """one epoch distillation"""
65 | # set modules as train()
66 | for module in module_list:
67 | module.train()
68 | # set teacher as eval()
69 | module_list[-1].eval()
70 |
71 | criterion_cls = criterion_list[0]
72 | criterion_div = criterion_list[1]
73 | criterion_kd = criterion_list[2]
74 |
75 | model_s = module_list[0]
76 | model_t = module_list[-1]
77 |
78 | batch_time = AverageMeter()
79 | losses = AverageMeter()
80 | top1 = AverageMeter()
81 | top5 = AverageMeter()
82 |
83 | n_batch = len(train_loader) if opt.dali is None else (train_loader._size + opt.batch_size - 1) // opt.batch_size
84 |
85 | end = time.time()
86 | for idx, data in enumerate(train_loader):
87 | if opt.dali is None:
88 | if opt.distill in ['crd']:
89 | images, labels, index, contrast_idx = data
90 | else:
91 | images, labels = data
92 | else:
93 | images, labels = data[0]['data'], data[0]['label'].squeeze().long()
94 |
95 | if opt.distill == 'semckd' and images.shape[0] < opt.batch_size:
96 | continue
97 |
98 | if opt.gpu is not None:
99 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
100 | if torch.cuda.is_available():
101 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
102 | if opt.distill in ['crd']:
103 | index = index.cuda()
104 | contrast_idx = contrast_idx.cuda()
105 |
106 | # ===================forward=====================
107 | feat_s, logit_s = model_s(images, is_feat=True)
108 | with torch.no_grad():
109 | feat_t, logit_t = model_t(images, is_feat=True)
110 | feat_t = [f.detach() for f in feat_t]
111 |
112 | cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1]
113 |
114 | # cls + kl div
115 | loss_cls = criterion_cls(logit_s, labels)
116 | loss_div = criterion_div(logit_s, logit_t)
117 |
118 | # other kd loss
119 | if opt.distill == 'kd':
120 | loss_kd = 0
121 | elif opt.distill == 'hint':
122 | f_s, f_t = module_list[1](feat_s[opt.hint_layer], feat_t[opt.hint_layer])
123 | loss_kd = criterion_kd(f_s, f_t)
124 | elif opt.distill == 'attention':
125 | # include 1, exclude -1.
126 | g_s = feat_s[1:-1]
127 | g_t = feat_t[1:-1]
128 | loss_group = criterion_kd(g_s, g_t)
129 | loss_kd = sum(loss_group)
130 | elif opt.distill == 'similarity':
131 | g_s = [feat_s[-2]]
132 | g_t = [feat_t[-2]]
133 | loss_group = criterion_kd(g_s, g_t)
134 | loss_kd = sum(loss_group)
135 | elif opt.distill == 'vid':
136 | g_s = feat_s[1:-1]
137 | g_t = feat_t[1:-1]
138 | loss_group = [c(f_s, f_t) for f_s, f_t, c in zip(g_s, g_t, criterion_kd)]
139 | loss_kd = sum(loss_group)
140 | elif opt.distill == 'crd':
141 | f_s = feat_s[-1]
142 | f_t = feat_t[-1]
143 | loss_kd = criterion_kd(f_s, f_t, index, contrast_idx)
144 | elif opt.distill == 'semckd':
145 | s_value, f_target, weight = module_list[1](feat_s[1:-1], feat_t[1:-1])
146 | loss_kd = criterion_kd(s_value, f_target, weight)
147 | elif opt.distill == 'srrl':
148 | trans_feat_s, pred_feat_s = module_list[1](feat_s[-1], cls_t)
149 | loss_kd = criterion_kd(trans_feat_s, feat_t[-1]) + criterion_kd(pred_feat_s, logit_t)
150 | elif opt.distill == 'simkd':
151 | trans_feat_s, trans_feat_t, pred_feat_s = module_list[1](feat_s[-2], feat_t[-2], cls_t)
152 | logit_s = pred_feat_s
153 | loss_kd = criterion_kd(trans_feat_s, trans_feat_t)
154 | else:
155 | raise NotImplementedError(opt.distill)
156 |
157 | loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd
158 | losses.update(loss.item(), images.size(0))
159 |
160 | # ===================Metrics=====================
161 | metrics = accuracy(logit_s, labels, topk=(1, 5))
162 | top1.update(metrics[0].item(), images.size(0))
163 | top5.update(metrics[1].item(), images.size(0))
164 |         batch_time.update(time.time() - end); end = time.time()  # reset so each batch is timed separately
165 |
166 | # ===================backward=====================
167 | optimizer.zero_grad()
168 | loss.backward()
169 | optimizer.step()
170 |
171 | # print info
172 | if idx % opt.print_freq == 0:
173 | print('Epoch: [{0}][{1}/{2}]\t'
174 | 'GPU {3}\t'
175 | 'Time: {batch_time.avg:.3f}\t'
176 | 'Loss {loss.avg:.4f}\t'
177 | 'Acc@1 {top1.avg:.3f}\t'
178 | 'Acc@5 {top5.avg:.3f}'.format(
179 | epoch, idx, n_batch, opt.gpu, loss=losses, top1=top1, top5=top5,
180 | batch_time=batch_time))
181 | sys.stdout.flush()
182 |
183 | return top1.avg, top5.avg, losses.avg
184 |
185 | def validate_vanilla(val_loader, model, criterion, opt):
186 | """validation"""
187 |
188 | batch_time = AverageMeter()
189 | losses = AverageMeter()
190 | top1 = AverageMeter()
191 | top5 = AverageMeter()
192 |
193 | # switch to evaluate mode
194 | model.eval()
195 |
196 | n_batch = len(val_loader) if opt.dali is None else (val_loader._size + opt.batch_size - 1) // opt.batch_size
197 |
198 | with torch.no_grad():
199 | end = time.time()
200 | for idx, batch_data in enumerate(val_loader):
201 |
202 | if opt.dali is None:
203 | images, labels = batch_data
204 | else:
205 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long()
206 |
207 | if opt.gpu is not None:
208 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
209 | if torch.cuda.is_available():
210 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
211 |
212 | # compute output
213 | output = model(images)
214 | loss = criterion(output, labels)
215 | losses.update(loss.item(), images.size(0))
216 |
217 | # ===================Metrics=====================
218 | metrics = accuracy(output, labels, topk=(1, 5))
219 | top1.update(metrics[0].item(), images.size(0))
220 | top5.update(metrics[1].item(), images.size(0))
221 |             batch_time.update(time.time() - end); end = time.time()  # reset so each batch is timed separately
222 |
223 | if idx % opt.print_freq == 0:
224 | print('Test: [{0}/{1}]\t'
225 | 'GPU: {2}\t'
226 | 'Time: {batch_time.avg:.3f}\t'
227 | 'Loss {loss.avg:.4f}\t'
228 | 'Acc@1 {top1.avg:.3f}\t'
229 | 'Acc@5 {top5.avg:.3f}'.format(
230 | idx, n_batch, opt.gpu, batch_time=batch_time, loss=losses,
231 | top1=top1, top5=top5))
232 |
233 | if opt.multiprocessing_distributed:
234 | # Batch size may not be equal across multiple gpus
235 | total_metrics = torch.tensor([top1.sum, top5.sum, losses.sum]).to(opt.gpu)
236 | count_metrics = torch.tensor([top1.count, top5.count, losses.count]).to(opt.gpu)
237 |         total_metrics = reduce_tensor(total_metrics, 1) # world_size=1 so the per-GPU totals are summed, not averaged
238 | count_metrics = reduce_tensor(count_metrics, 1)
239 | ret = []
240 | for s, n in zip(total_metrics.tolist(), count_metrics.tolist()):
241 | ret.append(s / (1.0 * n))
242 | return ret
243 |
244 | return top1.avg, top5.avg, losses.avg
245 |
246 |
247 | def validate_distill(val_loader, module_list, criterion, opt):
248 | """validation"""
249 |
250 | batch_time = AverageMeter()
251 | losses = AverageMeter()
252 | top1 = AverageMeter()
253 | top5 = AverageMeter()
254 |
255 | # switch to evaluate mode
256 | for module in module_list:
257 | module.eval()
258 |
259 | model_s = module_list[0]
260 | model_t = module_list[-1]
261 | n_batch = len(val_loader) if opt.dali is None else (val_loader._size + opt.batch_size - 1) // opt.batch_size
262 |
263 | with torch.no_grad():
264 | end = time.time()
265 | for idx, batch_data in enumerate(val_loader):
266 |
267 | if opt.dali is None:
268 | images, labels = batch_data
269 | else:
270 | images, labels = batch_data[0]['data'], batch_data[0]['label'].squeeze().long()
271 |
272 | if opt.gpu is not None:
273 | images = images.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
274 | if torch.cuda.is_available():
275 | labels = labels.cuda(opt.gpu if opt.multiprocessing_distributed else 0, non_blocking=True)
276 |
277 | # compute output
278 | if opt.distill == 'simkd':
279 | feat_s, _ = model_s(images, is_feat=True)
280 | feat_t, _ = model_t(images, is_feat=True)
281 | feat_t = [f.detach() for f in feat_t]
282 | cls_t = model_t.module.get_feat_modules()[-1] if opt.multiprocessing_distributed else model_t.get_feat_modules()[-1]
283 | _, _, output = module_list[1](feat_s[-2], feat_t[-2], cls_t)
284 | else:
285 | output = model_s(images)
286 | loss = criterion(output, labels)
287 | losses.update(loss.item(), images.size(0))
288 |
289 | # ===================Metrics=====================
290 | metrics = accuracy(output, labels, topk=(1, 5))
291 | top1.update(metrics[0].item(), images.size(0))
292 | top5.update(metrics[1].item(), images.size(0))
293 |             batch_time.update(time.time() - end); end = time.time()  # reset so each batch is timed separately
294 |
295 | if idx % opt.print_freq == 0:
296 | print('Test: [{0}/{1}]\t'
297 | 'GPU: {2}\t'
298 | 'Time: {batch_time.avg:.3f}\t'
299 | 'Loss {loss.avg:.4f}\t'
300 | 'Acc@1 {top1.avg:.3f}\t'
301 | 'Acc@5 {top5.avg:.3f}'.format(
302 | idx, n_batch, opt.gpu, batch_time=batch_time, loss=losses,
303 | top1=top1, top5=top5))
304 |
305 | if opt.multiprocessing_distributed:
306 | # Batch size may not be equal across multiple gpus
307 | total_metrics = torch.tensor([top1.sum, top5.sum, losses.sum]).to(opt.gpu)
308 | count_metrics = torch.tensor([top1.count, top5.count, losses.count]).to(opt.gpu)
309 |         total_metrics = reduce_tensor(total_metrics, 1) # world_size=1 so the per-GPU totals are summed, not averaged
310 | count_metrics = reduce_tensor(count_metrics, 1)
311 | ret = []
312 | for s, n in zip(total_metrics.tolist(), count_metrics.tolist()):
313 | ret.append(s / (1.0 * n))
314 | return ret
315 |
316 | return top1.avg, top5.avg, losses.avg
317 |
--------------------------------------------------------------------------------
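Note: the total objective assembled in `train_distill` above is `loss = opt.cls * loss_cls + opt.div * loss_div + opt.beta * loss_kd`, i.e. cross-entropy plus a logit-level KL term plus the method-specific feature term. The logit-level criterion (`criterion_div`) lives in `distiller_zoo/KD.py`; a minimal sketch of the usual temperature-scaled KL divergence it implements (illustrative, not necessarily that file's exact code):

```python
import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """Temperature-scaled KL divergence between student and teacher logits."""
    def __init__(self, T=4.0):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, logit_s, logit_t):
        p_s = F.log_softmax(logit_s / self.T, dim=1)
        p_t = F.softmax(logit_t / self.T, dim=1)
        # multiply by T^2 to keep gradient magnitudes comparable across temperatures
        return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)
```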
/helper/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import json
4 | import torch
5 | import numpy as np
6 | import torch.distributed as dist
7 |
8 | def adjust_learning_rate(epoch, opt, optimizer):
9 | """Sets the learning rate to the initial LR decayed by decay rate every steep step"""
10 | steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
11 | if steps > 0:
12 | new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
13 | for param_group in optimizer.param_groups:
14 | param_group['lr'] = new_lr
15 |
16 | # def adjust_learning_rate(optimizer, epoch, step, len_epoch, old_lr):
17 | # """Sets the learning rate to the initial LR decayed by decay rate every steep step"""
18 | # if epoch < 5:
19 | # lr = old_lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
20 | # elif 5 <= epoch < 60: return
21 | # else:
22 | # factor = epoch // 30
23 | # factor -= 1
24 | # lr = old_lr*(0.1**factor)
25 |
26 | # for param_group in optimizer.param_groups:
27 | # param_group['lr'] = lr
28 |
29 |
30 | class AverageMeter(object):
31 | """Computes and stores the average and current value"""
32 | def __init__(self):
33 | self.reset()
34 |
35 | def reset(self):
36 | self.val = 0
37 | self.avg = 0
38 | self.sum = 0
39 | self.count = 0
40 |
41 | def update(self, val, n=1):
42 | self.val = val
43 | self.sum += val * n
44 | self.count += n
45 | self.avg = self.sum / self.count
46 |
47 | def accuracy(output, target, topk=(1,)):
48 | """Computes the accuracy over the k top predictions for the specified values of k"""
49 | with torch.no_grad():
50 | maxk = max(topk)
51 | batch_size = target.size(0)
52 |
53 |         _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
54 | pred = pred.t()
55 | correct = pred.eq(target.view(1, -1).expand_as(pred))
56 |
57 | res = []
58 | for k in topk:
59 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
60 | res.append(correct_k.mul_(100.0 / batch_size))
61 | return res
62 |
63 | def save_dict_to_json(d, json_path):
64 | """Saves dict of floats in json file
65 |
66 | Args:
67 | d: (dict) of float-castable values (np.float, int, float, etc.)
68 | json_path: (string) path to json file
69 | """
70 | with open(json_path, 'w') as f:
71 |         # convert the values to float for json (it doesn't accept np.array, np.float, ...)
72 |         d = {k: float(v) for k, v in d.items()}
73 | json.dump(d, f, indent=4)
74 |
75 | def load_json_to_dict(json_path):
76 | """Loads json file to dict
77 |
78 | Args:
79 | json_path: (string) path to json file
80 | """
81 | with open(json_path, 'r') as f:
82 | params = json.load(f)
83 | return params
84 |
85 | def reduce_tensor(tensor, world_size=1, op='avg'):  # 'op' is currently unused: the all-reduce is always SUM, averaged if world_size > 1
86 | rt = tensor.clone()
87 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
88 | if world_size > 1:
89 | rt = torch.true_divide(rt, world_size)
90 | return rt
91 |
92 | if __name__ == '__main__':
93 |
94 | pass
95 |
--------------------------------------------------------------------------------
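A quick usage sketch for the helpers above (illustrative values): `accuracy` returns top-k percentages as one-element tensors, `AverageMeter` aggregates them weighted by batch size, and `adjust_learning_rate` multiplies the initial LR by `lr_decay_rate` once for each milestone already passed.

```python
import numpy as np
import torch
from helper.util import AverageMeter, accuracy

logits = torch.randn(8, 100)                 # batch of 8, 100 classes
labels = torch.randint(0, 100, (8,))
top1, top5 = accuracy(logits, labels, topk=(1, 5))

meter = AverageMeter()
meter.update(top1.item(), n=logits.size(0))  # weight by batch size
print(meter.avg)

# step decay: with learning_rate=0.05, lr_decay_rate=0.1 and
# lr_decay_epochs=[150, 180, 210], epoch 181 has passed two milestones:
steps = np.sum(181 > np.asarray([150, 180, 210]))  # -> 2
print(0.05 * 0.1 ** steps)                         # -> 0.0005
```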
/images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/.DS_Store
--------------------------------------------------------------------------------
/images/SimKD_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/SimKD_result.png
--------------------------------------------------------------------------------
/images/cifar100_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/images/cifar100_result.png
--------------------------------------------------------------------------------
/models/ShuffleNetv1.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 | def __init__(self, groups):
11 | super(ShuffleBlock, self).__init__()
12 | self.groups = groups
13 |
14 | def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 | N,C,H,W = x.size()
17 | g = self.groups
18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W)
19 |
20 |
21 | class Bottleneck(nn.Module):
22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False):
23 | super(Bottleneck, self).__init__()
24 | self.is_last = is_last
25 | self.stride = stride
26 |
27 | mid_planes = int(out_planes/4)
28 | g = 1 if in_planes == 24 else groups
29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 | self.bn1 = nn.BatchNorm2d(mid_planes)
31 | self.shuffle1 = ShuffleBlock(groups=g)
32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 | self.bn2 = nn.BatchNorm2d(mid_planes)
34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 | self.bn3 = nn.BatchNorm2d(out_planes)
36 |
37 | self.shortcut = nn.Sequential()
38 | if stride == 2:
39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 |
41 | def forward(self, x):
42 | out = F.relu(self.bn1(self.conv1(x)))
43 | out = self.shuffle1(out)
44 | out = F.relu(self.bn2(self.conv2(out)))
45 | out = self.bn3(self.conv3(out))
46 | res = self.shortcut(x)
47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res
48 | out = F.relu(preact)
49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res)
50 | if self.is_last:
51 | return out, preact
52 | else:
53 | return out
54 |
55 |
56 | class ShuffleNet(nn.Module):
57 | def __init__(self, cfg, num_classes=10):
58 | super(ShuffleNet, self).__init__()
59 | out_planes = cfg['out_planes']
60 | num_blocks = cfg['num_blocks']
61 | groups = cfg['groups']
62 |
63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
64 | self.bn1 = nn.BatchNorm2d(24)
65 | self.in_planes = 24
66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
69 | self.linear = nn.Linear(out_planes[2], num_classes)
70 |
71 | def _make_layer(self, out_planes, num_blocks, groups):
72 | layers = []
73 | for i in range(num_blocks):
74 | stride = 2 if i == 0 else 1
75 | cat_planes = self.in_planes if i == 0 else 0
76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes,
77 | stride=stride,
78 | groups=groups,
79 | is_last=(i == num_blocks - 1)))
80 | self.in_planes = out_planes
81 | return nn.Sequential(*layers)
82 |
83 | def get_feat_modules(self):
84 | feat_m = nn.ModuleList([])
85 | feat_m.append(self.conv1)
86 | feat_m.append(self.bn1)
87 | feat_m.append(self.layer1)
88 | feat_m.append(self.layer2)
89 | feat_m.append(self.layer3)
90 | return feat_m
91 |
92 | def get_bn_before_relu(self):
93 |         raise NotImplementedError('ShuffleNet is currently not supported as an "Overhaul" teacher')
94 |
95 | def forward(self, x, is_feat=False, preact=False):
96 | out = F.relu(self.bn1(self.conv1(x)))
97 | f0 = out
98 | out, f1_pre = self.layer1(out)
99 | f1 = out
100 | out, f2_pre = self.layer2(out)
101 | f2 = out
102 | out, f3_pre = self.layer3(out)
103 | f3 = out
104 | out = F.avg_pool2d(out, 4)
105 | out = out.view(out.size(0), -1)
106 | f4 = out
107 | out = self.linear(out)
108 |
109 | if is_feat:
110 | if preact:
111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out
112 | else:
113 | return [f0, f1, f2, f3, f4], out
114 | else:
115 | return out
116 |
117 |
118 | def ShuffleV1(**kwargs):
119 | cfg = {
120 | 'out_planes': [240, 480, 960],
121 | 'num_blocks': [4, 8, 4],
122 | 'groups': 3
123 | }
124 | return ShuffleNet(cfg, **kwargs)
125 |
126 |
127 | if __name__ == '__main__':
128 |
129 | x = torch.randn(2, 3, 32, 32)
130 | net = ShuffleV1(num_classes=100)
131 | import time
132 | a = time.time()
133 | feats, logit = net(x, is_feat=True, preact=True)
134 | b = time.time()
135 | print(b - a)
136 | for f in feats:
137 | print(f.shape, f.min().item())
138 | print(logit.shape)
139 |
--------------------------------------------------------------------------------
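The channel shuffle in `ShuffleBlock` is easiest to verify on a tiny tensor: viewing [N,C,H,W] as [N,g,C/g,H,W] and swapping the two channel axes interleaves the groups, so information can flow between grouped convolutions. A minimal check with illustrative values:

```python
import torch

x = torch.arange(6).view(1, 6, 1, 1)   # channels [0, 1, 2, 3, 4, 5], two groups of three
g = 2
y = x.view(1, g, 6 // g, 1, 1).permute(0, 2, 1, 3, 4).reshape(1, 6, 1, 1)
print(y.flatten().tolist())            # -> [0, 3, 1, 4, 2, 5]: the groups are interleaved
```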
/models/ShuffleNetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 | def __init__(self, groups=2):
11 | super(ShuffleBlock, self).__init__()
12 | self.groups = groups
13 |
14 | def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 | N, C, H, W = x.size()
17 | g = self.groups
18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
19 |
20 |
21 | class SplitBlock(nn.Module):
22 | def __init__(self, ratio):
23 | super(SplitBlock, self).__init__()
24 | self.ratio = ratio
25 |
26 | def forward(self, x):
27 | c = int(x.size(1) * self.ratio)
28 | return x[:, :c, :, :], x[:, c:, :, :]
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False):
33 | super(BasicBlock, self).__init__()
34 | self.is_last = is_last
35 | self.split = SplitBlock(split_ratio)
36 | in_channels = int(in_channels * split_ratio)
37 | self.conv1 = nn.Conv2d(in_channels, in_channels,
38 | kernel_size=1, bias=False)
39 | self.bn1 = nn.BatchNorm2d(in_channels)
40 | self.conv2 = nn.Conv2d(in_channels, in_channels,
41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 | self.bn2 = nn.BatchNorm2d(in_channels)
43 | self.conv3 = nn.Conv2d(in_channels, in_channels,
44 | kernel_size=1, bias=False)
45 | self.bn3 = nn.BatchNorm2d(in_channels)
46 | self.shuffle = ShuffleBlock()
47 |
48 | def forward(self, x):
49 | x1, x2 = self.split(x)
50 | out = F.relu(self.bn1(self.conv1(x2)))
51 | out = self.bn2(self.conv2(out))
52 | preact = self.bn3(self.conv3(out))
53 | out = F.relu(preact)
54 | # out = F.relu(self.bn3(self.conv3(out)))
55 | preact = torch.cat([x1, preact], 1)
56 | out = torch.cat([x1, out], 1)
57 | out = self.shuffle(out)
58 | if self.is_last:
59 | return out, preact
60 | else:
61 | return out
62 |
63 |
64 | class DownBlock(nn.Module):
65 | def __init__(self, in_channels, out_channels):
66 | super(DownBlock, self).__init__()
67 | mid_channels = out_channels // 2
68 | # left
69 | self.conv1 = nn.Conv2d(in_channels, in_channels,
70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
71 | self.bn1 = nn.BatchNorm2d(in_channels)
72 | self.conv2 = nn.Conv2d(in_channels, mid_channels,
73 | kernel_size=1, bias=False)
74 | self.bn2 = nn.BatchNorm2d(mid_channels)
75 | # right
76 | self.conv3 = nn.Conv2d(in_channels, mid_channels,
77 | kernel_size=1, bias=False)
78 | self.bn3 = nn.BatchNorm2d(mid_channels)
79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels,
80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
81 | self.bn4 = nn.BatchNorm2d(mid_channels)
82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels,
83 | kernel_size=1, bias=False)
84 | self.bn5 = nn.BatchNorm2d(mid_channels)
85 |
86 | self.shuffle = ShuffleBlock()
87 |
88 | def forward(self, x):
89 | # left
90 | out1 = self.bn1(self.conv1(x))
91 | out1 = F.relu(self.bn2(self.conv2(out1)))
92 | # right
93 | out2 = F.relu(self.bn3(self.conv3(x)))
94 | out2 = self.bn4(self.conv4(out2))
95 | out2 = F.relu(self.bn5(self.conv5(out2)))
96 | # concat
97 | out = torch.cat([out1, out2], 1)
98 | out = self.shuffle(out)
99 | return out
100 |
101 |
102 | class ShuffleNetV2(nn.Module):
103 | def __init__(self, net_size, num_classes=10):
104 | super(ShuffleNetV2, self).__init__()
105 | out_channels = configs[net_size]['out_channels']
106 | num_blocks = configs[net_size]['num_blocks']
107 |
108 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
109 | # stride=1, padding=1, bias=False)
110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
111 | self.bn1 = nn.BatchNorm2d(24)
112 | self.in_channels = 24
113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
117 | kernel_size=1, stride=1, padding=0, bias=False)
118 | self.bn2 = nn.BatchNorm2d(out_channels[3])
119 | self.linear = nn.Linear(out_channels[3], num_classes)
120 |
121 | def _make_layer(self, out_channels, num_blocks):
122 | layers = [DownBlock(self.in_channels, out_channels)]
123 | for i in range(num_blocks):
124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1)))
125 | self.in_channels = out_channels
126 | return nn.Sequential(*layers)
127 |
128 | def get_feat_modules(self):
129 | feat_m = nn.ModuleList([])
130 | feat_m.append(self.conv1)
131 | feat_m.append(self.bn1)
132 | feat_m.append(self.layer1)
133 | feat_m.append(self.layer2)
134 | feat_m.append(self.layer3)
135 | return feat_m
136 |
137 | def get_bn_before_relu(self):
138 |         raise NotImplementedError('ShuffleNetV2 is currently not supported as an "Overhaul" teacher')
139 |
140 | def forward(self, x, is_feat=False, preact=False):
141 | out = F.relu(self.bn1(self.conv1(x)))
142 | # out = F.max_pool2d(out, 3, stride=2, padding=1)
143 | f0 = out
144 | out, f1_pre = self.layer1(out)
145 | f1 = out
146 | out, f2_pre = self.layer2(out)
147 | f2 = out
148 | out, f3_pre = self.layer3(out)
149 | f3 = out
150 | out = F.relu(self.bn2(self.conv2(out)))
151 | out = F.avg_pool2d(out, 4)
152 | out = out.view(out.size(0), -1)
153 | f4 = out
154 | out = self.linear(out)
155 | if is_feat:
156 | if preact:
157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out
158 | else:
159 | return [f0, f1, f2, f3, f4], out
160 | else:
161 | return out
162 |
163 |
164 | configs = {
165 | 0.2: {
166 | 'out_channels': (40, 80, 160, 512),
167 | 'num_blocks': (3, 3, 3)
168 | },
169 |
170 | 0.3: {
171 | 'out_channels': (40, 80, 160, 512),
172 | 'num_blocks': (3, 7, 3)
173 | },
174 |
175 | 0.5: {
176 | 'out_channels': (48, 96, 192, 1024),
177 | 'num_blocks': (3, 7, 3)
178 | },
179 |
180 | 1: {
181 | 'out_channels': (116, 232, 464, 1024),
182 | 'num_blocks': (3, 7, 3)
183 | },
184 | 1.5: {
185 | 'out_channels': (176, 352, 704, 1024),
186 | 'num_blocks': (3, 7, 3)
187 | },
188 | 2: {
189 | 'out_channels': (224, 488, 976, 2048),
190 | 'num_blocks': (3, 7, 3)
191 | }
192 | }
193 |
194 |
195 | def ShuffleV2_0_2(**kwargs):
196 | model = ShuffleNetV2(net_size=0.2, **kwargs)
197 | return model
198 |
199 | def ShuffleV2_0_5(**kwargs):
200 | model = ShuffleNetV2(net_size=0.5, **kwargs)
201 | return model
202 |
203 | def ShuffleV2(**kwargs):
204 | model = ShuffleNetV2(net_size=1, **kwargs)
205 | return model
206 |
207 | def ShuffleV2_1_5(**kwargs):
208 | model = ShuffleNetV2(net_size=1.5, **kwargs)
209 | return model
210 |
211 | def ShuffleV2_2_0(**kwargs):
212 | model = ShuffleNetV2(net_size=2.0, **kwargs)
213 | return model
214 |
215 | if __name__ == '__main__':
216 | net = ShuffleV2(num_classes=100)
217 | x = torch.randn(3, 3, 32, 32)
218 | import time
219 | a = time.time()
220 | feats, logit = net(x, is_feat=True, preact=True)
221 | b = time.time()
222 | print(b - a)
223 | for f in feats:
224 | print(f.shape, f.min().item())
225 | print(logit.shape)
226 |
--------------------------------------------------------------------------------
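`BasicBlock` above implements the ShuffleNetV2 channel split: with the default `split_ratio=0.5`, half of the channels bypass the block unchanged while the other half pass through the 1x1 -> depthwise 3x3 -> 1x1 stack; the halves are then concatenated and shuffled, so the output width equals the input width. A quick shape check (illustrative):

```python
import torch

block = BasicBlock(in_channels=48)   # 24 identity channels + 24 transformed channels
x = torch.randn(2, 48, 16, 16)
out = block(x)
print(out.shape)                     # -> torch.Size([2, 48, 16, 16]), width preserved
```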
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet38, resnet110, resnet116, resnet14x2, resnet38x2, resnet110x2
2 | from .resnet import resnet8x4, resnet14x4, resnet32x4, resnet38x4
3 | from .vgg import vgg8_bn, vgg13_bn
4 | from .mobilenetv2 import mobile_half, mobile_half_double
5 | from .ShuffleNetv1 import ShuffleV1
6 | from .ShuffleNetv2 import ShuffleV2, ShuffleV2_1_5
7 |
8 | from .resnet_imagenet import resnet18, resnet34, resnet50, wide_resnet50_2, resnext50_32x4d
9 | from .resnet_imagenet import wide_resnet10_2, wide_resnet18_2, wide_resnet34_2
10 | from .mobilenetv2_imagenet import mobilenet_v2
11 | from .shuffleNetv2_imagenet import shufflenet_v2_x1_0
12 |
13 | model_dict = {
14 | 'resnet38': resnet38,
15 | 'resnet110': resnet110,
16 | 'resnet116': resnet116,
17 | 'resnet14x2': resnet14x2,
18 | 'resnet38x2': resnet38x2,
19 | 'resnet110x2': resnet110x2,
20 | 'resnet8x4': resnet8x4,
21 | 'resnet14x4': resnet14x4,
22 | 'resnet32x4': resnet32x4,
23 | 'resnet38x4': resnet38x4,
24 | 'vgg8': vgg8_bn,
25 | 'vgg13': vgg13_bn,
26 | 'MobileNetV2': mobile_half,
27 | 'MobileNetV2_1_0': mobile_half_double,
28 | 'ShuffleV1': ShuffleV1,
29 | 'ShuffleV2': ShuffleV2,
30 | 'ShuffleV2_1_5': ShuffleV2_1_5,
31 |
32 | 'ResNet18': resnet18,
33 | 'ResNet34': resnet34,
34 | 'ResNet50': resnet50,
35 | 'resnext50_32x4d': resnext50_32x4d,
36 | 'ResNet10x2': wide_resnet10_2,
37 | 'ResNet18x2': wide_resnet18_2,
38 | 'ResNet34x2': wide_resnet34_2,
39 | 'wrn_50_2': wide_resnet50_2,
40 |
41 | 'MobileNetV2_Imagenet': mobilenet_v2,
42 | 'ShuffleV2_Imagenet': shufflenet_v2_x1_0,
43 | }
44 |
--------------------------------------------------------------------------------
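`model_dict` maps the model names used on the command line to their constructors, so a network can be built by name; for example (run from the repository root):

```python
from models import model_dict

model_t = model_dict['resnet32x4'](num_classes=100)  # CIFAR-100 teacher
model_s = model_dict['ShuffleV2'](num_classes=100)   # CIFAR-100 student
```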
/models/__pycache__/ShuffleNetv1.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv1.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ShuffleNetv1.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv1.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ShuffleNetv2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv2.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ShuffleNetv2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/ShuffleNetv2.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenetv2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenetv2.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenetv2_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/mobilenetv2_imagenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/mobilenetv2_imagenet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet_imagenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/resnet_imagenet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/shuffleNetv2_imagenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/shuffleNetv2_imagenet.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/shuffleNetv2_imagenet.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/shuffleNetv2_imagenet.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vgg.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/vgg.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/vgg.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DefangChen/SimKD/2b389c31ed7779aea31e7aaf0bb0f2d8b6ac2f01/models/__pycache__/vgg.cpython-39.pyc
--------------------------------------------------------------------------------
/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | MobileNetV2 implementation for 32x32 CIFAR-style inputs
3 | (the ImageNet variant lives in mobilenetv2_imagenet.py).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 |
10 | __all__ = ['mobilenetv2_T_w', 'mobile_half', 'mobile_half_double']
11 |
12 | BN = None
13 |
14 |
15 | def conv_bn(inp, oup, stride):
16 | return nn.Sequential(
17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
18 | nn.BatchNorm2d(oup),
19 | nn.ReLU(inplace=True)
20 | )
21 |
22 |
23 | def conv_1x1_bn(inp, oup):
24 | return nn.Sequential(
25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
26 | nn.BatchNorm2d(oup),
27 | nn.ReLU(inplace=True)
28 | )
29 |
30 |
31 | class InvertedResidual(nn.Module):
32 | def __init__(self, inp, oup, stride, expand_ratio):
33 | super(InvertedResidual, self).__init__()
34 | self.blockname = None
35 |
36 | self.stride = stride
37 | assert stride in [1, 2]
38 |
39 | self.use_res_connect = self.stride == 1 and inp == oup
40 |
41 | self.conv = nn.Sequential(
42 | # pw
43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
44 | nn.BatchNorm2d(inp * expand_ratio),
45 | nn.ReLU(inplace=True),
46 | # dw
47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False),
48 | nn.BatchNorm2d(inp * expand_ratio),
49 | nn.ReLU(inplace=True),
50 | # pw-linear
51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
52 | nn.BatchNorm2d(oup),
53 | )
54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7']
55 |
56 | def forward(self, x):
57 | t = x
58 | if self.use_res_connect:
59 | return t + self.conv(x)
60 | else:
61 | return self.conv(x)
62 |
63 |
64 | class MobileNetV2(nn.Module):
65 | """mobilenetV2"""
66 | def __init__(self, T,
67 | feature_dim,
68 | input_size=32,
69 | width_mult=1.,
70 | remove_avg=False):
71 | super(MobileNetV2, self).__init__()
72 | self.remove_avg = remove_avg
73 |
74 | # setting of inverted residual blocks
75 |         self.inverted_residual_setting = [
76 | # t, c, n, s
77 | [1, 16, 1, 1],
78 | [T, 24, 2, 1],
79 | [T, 32, 3, 2],
80 | [T, 64, 4, 2],
81 | [T, 96, 3, 1],
82 | [T, 160, 3, 2],
83 | [T, 320, 1, 1],
84 | ]
85 |
86 | # building first layer
87 | assert input_size % 32 == 0
88 | input_channel = int(32 * width_mult)
89 | self.conv1 = conv_bn(3, input_channel, 2)
90 |
91 | # building inverted residual blocks
92 | self.blocks = nn.ModuleList([])
93 |         for t, c, n, s in self.inverted_residual_setting:
94 | output_channel = int(c * width_mult)
95 | layers = []
96 | strides = [s] + [1] * (n - 1)
97 | for stride in strides:
98 | layers.append(
99 | InvertedResidual(input_channel, output_channel, stride, t)
100 | )
101 | input_channel = output_channel
102 | self.blocks.append(nn.Sequential(*layers))
103 |
104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel)
106 |
107 | # building classifier
108 | self.classifier = nn.Sequential(
109 | # nn.Dropout(0.5),
110 | nn.Linear(self.last_channel, feature_dim),
111 | )
112 |
113 |         H = input_size // 16  # total downsampling: conv1 (stride 2) times three stride-2 stages
114 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True)
115 |
116 | self._initialize_weights()
117 | print(T, width_mult)
118 |
119 | def get_feat_modules(self):
120 | feat_m = nn.ModuleList([])
121 | feat_m.append(self.conv1)
122 | feat_m.append(self.blocks)
123 | return feat_m
124 |
125 | def forward(self, x, is_feat=False, preact=False):
126 |
127 | out = self.conv1(x)
128 | f0 = out
129 |
130 | out = self.blocks[0](out)
131 | out = self.blocks[1](out)
132 | f1 = out
133 | out = self.blocks[2](out)
134 | f2 = out
135 | out = self.blocks[3](out)
136 | out = self.blocks[4](out)
137 | f3 = out
138 | out = self.blocks[5](out)
139 | out = self.blocks[6](out)
140 | f4 = out
141 |
142 | out = self.conv2(out)
143 |
144 | if not self.remove_avg:
145 | out = self.avgpool(out)
146 | out = out.view(out.size(0), -1)
147 | f5 = out
148 | out = self.classifier(out)
149 |
150 | if is_feat:
151 | return [f0, f1, f2, f3, f4, f5], out
152 | else:
153 | return out
154 |
155 | def _initialize_weights(self):
156 | for m in self.modules():
157 | if isinstance(m, nn.Conv2d):
158 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
159 | m.weight.data.normal_(0, math.sqrt(2. / n))
160 | if m.bias is not None:
161 | m.bias.data.zero_()
162 | elif isinstance(m, nn.BatchNorm2d):
163 | m.weight.data.fill_(1)
164 | m.bias.data.zero_()
165 | elif isinstance(m, nn.Linear):
166 | n = m.weight.size(1)
167 | m.weight.data.normal_(0, 0.01)
168 | m.bias.data.zero_()
169 |
170 |
171 | def mobilenetv2_T_w(T, W, feature_dim=100):
172 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W)
173 | return model
174 |
175 | # To be consistent with the previous paper (CRD), MobileNetV2 is instantiated by mobile_half
176 | def mobile_half(num_classes):
177 | return mobilenetv2_T_w(6, 0.5, num_classes)
178 |
179 | # MobileNetV2x2 is instantiated by mobile_half_double
180 | def mobile_half_double(num_classes):
181 | return mobilenetv2_T_w(6, 1.0, num_classes)
182 |
183 | if __name__ == '__main__':
184 | x = torch.randn(2, 3, 32, 32)
185 |
186 | net = mobile_half(100)
187 |
188 | feats, logit = net(x, is_feat=True, preact=True)
189 | for f in feats:
190 | print(f.shape, f.min().item())
191 | print(logit.shape)
192 |
193 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
194 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
195 |
--------------------------------------------------------------------------------
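In the `(t, c, n, s)` table inside `MobileNetV2.__init__`, `t` is the expansion ratio, `c` the stage's output channels, `n` the number of blocks, and `s` the stride of the stage's first block; `mobile_half` applies `width_mult=0.5`, halving each stage's channel count. A quick check of the resulting widths (illustrative):

```python
setting = [[1, 16, 1, 1], [6, 24, 2, 1], [6, 32, 3, 2],
           [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2], [6, 320, 1, 1]]
width_mult = 0.5
print([int(c * width_mult) for _, c, _, _ in setting])
# -> [8, 12, 16, 32, 48, 80, 160]
```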
/models/mobilenetv2_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenet.py
2 | # Only two changes:
3 | # 1. MobileNetV2.forward() is modified to return inner feature maps.
4 | # 2. merge utils.py into this file to import load_state_dict_from_url.
5 |
6 | import torch
7 | from torch import nn
8 |
9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py
10 | try:
11 | from torch.hub import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 |
16 | __all__ = ['MobileNetV2', 'mobilenet_v2']
17 |
18 |
19 | model_urls = {
20 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
21 | }
22 |
23 |
24 | def _make_divisible(v, divisor, min_value=None):
25 | """
26 | This function is taken from the original tf repo.
27 | It ensures that all layers have a channel number that is divisible by 8
28 | It can be seen here:
29 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
30 | :param v:
31 | :param divisor:
32 | :param min_value:
33 | :return:
34 | """
35 | if min_value is None:
36 | min_value = divisor
37 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
38 | # Make sure that round down does not go down by more than 10%.
39 | if new_v < 0.9 * v:
40 | new_v += divisor
41 | return new_v
42 |
43 |
44 | class ConvBNReLU(nn.Sequential):
45 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, norm_layer=None):
46 | padding = (kernel_size - 1) // 2
47 | if norm_layer is None:
48 | norm_layer = nn.BatchNorm2d
49 | super(ConvBNReLU, self).__init__(
50 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
51 | norm_layer(out_planes),
52 | nn.ReLU6(inplace=True)
53 | )
54 |
55 |
56 | class InvertedResidual(nn.Module):
57 | def __init__(self, inp, oup, stride, expand_ratio, norm_layer=None):
58 | super(InvertedResidual, self).__init__()
59 | self.stride = stride
60 | assert stride in [1, 2]
61 |
62 | if norm_layer is None:
63 | norm_layer = nn.BatchNorm2d
64 |
65 | hidden_dim = int(round(inp * expand_ratio))
66 | self.use_res_connect = self.stride == 1 and inp == oup
67 |
68 | layers = []
69 | if expand_ratio != 1:
70 | # pw
71 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
72 | layers.extend([
73 | # dw
74 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
75 | # pw-linear
76 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
77 | norm_layer(oup),
78 | ])
79 | self.conv = nn.Sequential(*layers)
80 |
81 | def forward(self, x):
82 | if self.use_res_connect:
83 | return x + self.conv(x)
84 | else:
85 | return self.conv(x)
86 |
87 |
88 | class MobileNetV2(nn.Module):
89 | def __init__(self,
90 | num_classes=1000,
91 | width_mult=1.0,
92 | inverted_residual_setting=None,
93 | round_nearest=8,
94 | block=None,
95 | norm_layer=None):
96 | """
97 | MobileNet V2 main class
98 |
99 | Args:
100 | num_classes (int): Number of classes
101 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
102 | inverted_residual_setting: Network structure
103 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
104 | Set to 1 to turn off rounding
105 | block: Module specifying inverted residual building block for mobilenet
106 | norm_layer: Module specifying the normalization layer to use
107 |
108 | """
109 | super(MobileNetV2, self).__init__()
110 |
111 | if block is None:
112 | block = InvertedResidual
113 |
114 | if norm_layer is None:
115 | norm_layer = nn.BatchNorm2d
116 |
117 | input_channel = 32
118 | last_channel = 1280
119 |
120 | if inverted_residual_setting is None:
121 | inverted_residual_setting = [
122 | # t, c, n, s
123 | [1, 16, 1, 1],
124 | [6, 24, 2, 2],
125 | [6, 32, 3, 2],
126 | [6, 64, 4, 2],
127 | [6, 96, 3, 1],
128 | [6, 160, 3, 2],
129 | [6, 320, 1, 1],
130 | ]
131 |
132 | # only check the first element, assuming user knows t,c,n,s are required
133 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
134 |             raise ValueError("inverted_residual_setting should be non-empty "
135 |                              "and each element must be a 4-element list, got {}".format(inverted_residual_setting))
136 |
137 | # building first layer
138 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
139 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
140 | features = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
141 | # building inverted residual blocks
142 | for t, c, n, s in inverted_residual_setting:
143 | output_channel = _make_divisible(c * width_mult, round_nearest)
144 | for i in range(n):
145 | stride = s if i == 0 else 1
146 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer))
147 | input_channel = output_channel
148 | # building last several layers
149 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
150 | # do not use nn.Sequential
151 | # self.features = nn.Sequential(*features)
152 | self.features = nn.ModuleList(features)
153 |
154 | # building classifier
155 | self.classifier = nn.Sequential(
156 | nn.Dropout(0.2),
157 | nn.Linear(self.last_channel, num_classes),
158 | )
159 |
160 | # weight initialization
161 | for m in self.modules():
162 | if isinstance(m, nn.Conv2d):
163 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
164 | if m.bias is not None:
165 | nn.init.zeros_(m.bias)
166 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
167 | nn.init.ones_(m.weight)
168 | nn.init.zeros_(m.bias)
169 | elif isinstance(m, nn.Linear):
170 | nn.init.normal_(m.weight, 0, 0.01)
171 | nn.init.zeros_(m.bias)
172 |
173 | def _forward_impl(self, x):
174 | # This exists since TorchScript doesn't support inheritance, so the superclass method
175 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
176 | # x = self.features(x)
177 | for module in self.features:
178 | x = module(x)
179 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
180 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
181 | x = self.classifier(x)
182 | return x
183 |
184 | def forward(self, x, is_feat=False):
185 | if not is_feat:
186 | return self._forward_impl(x)
187 |
188 | splits = [0, 1, 4, 7, 14, 18]
189 | hidden_layers = []
190 | for left, right in zip(splits, splits[1:]):
191 | for module in self.features[left:right]:
192 | x = module(x)
193 | hidden_layers.append(x)
194 | for module in self.features[splits[-1]:]:
195 | x = module(x)
196 | x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
197 | hidden_layers.append(x)
198 | x = self.classifier(x)
199 | return hidden_layers, x
200 |
201 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
202 | """
203 | Constructs a MobileNetV2 architecture from
204 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
205 |
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | progress (bool): If True, displays a progress bar of the download to stderr
209 | """
210 | model = MobileNetV2(**kwargs)
211 | if pretrained:
212 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
213 | progress=progress)
214 | model.load_state_dict(state_dict)
215 | return model
216 |
217 | if __name__ == '__main__':
218 | #x = torch.randn(2, 3, 32, 32)
219 | x = torch.randn(2, 3, 224, 224)
220 |
221 | #net = mobile_half(100)
222 | net = MobileNetV2()
223 |
224 | feats, logit = net(x, is_feat=True)
225 | for f in feats:
226 | print(f.shape, f.min().item())
227 | print(logit.shape)
228 |
229 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
230 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
--------------------------------------------------------------------------------
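`_make_divisible` rounds a channel count to the nearest multiple of `divisor`, bumping the result up one step whenever rounding would drop it more than 10% below the requested value. Two illustrative cases:

```python
from models.mobilenetv2_imagenet import _make_divisible

print(_make_divisible(32 * 0.75, 8))  # 24 is already a multiple of 8 -> 24
print(_make_divisible(33, 8))         # nearest multiple is 32; 32 >= 0.9 * 33, so 32
```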
/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''
4 | ResNet for CIFAR-10/100 Dataset (Only BasicBlock is used).
5 | Reference:
6 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
7 | 2. https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua
8 | 3. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
9 | Deep Residual Learning for Image Recognition (CVPR 2016). https://arxiv.org/abs/1512.03385
10 |
11 | The Wide ResNet model is the same as ResNet except that the number of channels
12 | is doubled/quadrupled/k-times larger in every BasicBlock.
13 | Reference:
14 | 1. https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
15 | 2. Sergey Zagoruyko and Nikos Komodakis
16 | Wide Residual Networks (BMVC 2016) http://arxiv.org/abs/1605.07146
17 |
18 | P.S.
19 | Following the previous repository "https://github.com/HobbitLong/RepDistiller", the num_filters of the first conv is doubled in ResNet-8x4/32x4.
20 | The wide ResNet model in "/RepDistiller/models/wrn.py" is almost the same as the ResNet model in "/RepDistiller/models/resnet.py".
21 | For example, wrn_40_2 in "/RepDistiller/models/wrn.py" is almost equivalent to resnet38x2 in "/RepDistiller/models/resnet.py".
22 | The only difference is that resnet38x2 has three additional BN layers, which adds 2*(16+32+64)*k parameters [k=2 in this comparison].
23 | Therefore, it is recommended to directly use this file for the implementation of the Wide ResNet model.
24 | '''
25 |
26 | import torch.nn as nn
27 |
28 | __all__ = ['ResNet', 'resnet38', 'resnet110', 'resnet116', 'resnet14x2', 'resnet38x2', 'resnet110x2', 'resnet8x4', 'resnet14x4', 'resnet32x4', 'resnet38x4']
29 |
30 | def conv3x3(in_planes, out_planes, stride=1):
31 | """3x3 convolution with padding"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
33 |
34 | def conv1x1(in_planes, out_planes, stride=1):
35 | """1x1 convolution"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
37 |
38 | class BasicBlock(nn.Module):
39 | expansion = 1
40 |
41 | def __init__(self, inplanes, planes, stride=1, downsample=None):
42 | super(BasicBlock, self).__init__()
43 | self.conv1 = conv3x3(inplanes, planes, stride)
44 | self.bn1 = nn.BatchNorm2d(planes)
45 | self.relu = nn.ReLU(inplace=True)
46 | self.conv2 = conv3x3(planes, planes)
47 | self.bn2 = nn.BatchNorm2d(planes)
48 | self.downsample = downsample
49 | self.stride = stride
50 |
51 | def forward(self, x):
52 | residual = x
53 |
54 | out = self.conv1(x)
55 | out = self.bn1(out)
56 | out = self.relu(out)
57 |
58 | out = self.conv2(out)
59 | out = self.bn2(out)
60 |
61 | if self.downsample is not None:
62 | residual = self.downsample(x)
63 |
64 | out += residual
65 | out = self.relu(out)
66 | return out
67 |
68 | class Bottleneck(nn.Module):
69 | expansion = 4
70 |
71 | def __init__(self, inplanes, planes, stride=1, downsample=None):
72 | super(Bottleneck, self).__init__()
73 | self.conv1 = conv1x1(inplanes, planes)
74 | self.bn1 = nn.BatchNorm2d(planes)
75 |         self.conv2 = conv3x3(planes, planes, stride)  # stride belongs here so the output matches the downsampled residual
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = conv1x1(planes, planes * 4)
78 | self.bn3 = nn.BatchNorm2d(planes * 4)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 | return out
103 |
104 | class ResNet(nn.Module):
105 |
106 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
107 | super(ResNet, self).__init__()
108 | # Model type specifies number of layers for CIFAR-10 model
109 | if block_name.lower() == 'basicblock':
110 |             assert (depth - 2) % 6 == 0, 'When using BasicBlock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
111 | n = (depth - 2) // 6
112 | block = BasicBlock
113 | elif block_name.lower() == 'bottleneck':
114 |             assert (depth - 2) % 9 == 0, 'When using Bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
115 | n = (depth - 2) // 9
116 | block = Bottleneck
117 | else:
118 |             raise ValueError('block_name should be BasicBlock or Bottleneck')
119 |
120 | self.inplanes = num_filters[0]
121 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, bias=False)
122 | self.bn1 = nn.BatchNorm2d(num_filters[0])
123 | self.relu = nn.ReLU(inplace=True)
124 | self.layer1 = self._make_layer(block, num_filters[1], n)
125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
127 | self.avgpool = nn.AdaptiveAvgPool2d((1,1))
128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | conv1x1(self.inplanes, planes * block.expansion, stride),
142 | nn.BatchNorm2d(planes * block.expansion),
143 | )
144 |
145 |         layers = []
146 | layers.append(block(self.inplanes, planes, stride, downsample))
147 | self.inplanes = planes * block.expansion
148 | for i in range(1, blocks):
149 | layers.append(block(self.inplanes, planes))
150 |
151 | return nn.Sequential(*layers)
152 |
153 | def get_feat_modules(self):
154 | feat_m = nn.ModuleList([])
155 | feat_m.append(self.conv1)
156 | feat_m.append(self.bn1)
157 | feat_m.append(self.relu)
158 | feat_m.append(self.layer1)
159 | feat_m.append(self.layer2)
160 | feat_m.append(self.layer3)
161 | feat_m.append(self.fc)
162 | return feat_m
163 |
164 | def forward(self, x, is_feat=False):
165 |
166 | x = self.conv1(x)
167 | x = self.bn1(x)
168 | x = self.relu(x) # 32x32
169 | f0 = x
170 |
171 | x = self.layer1(x) # 32x32
172 | f1 = x
173 | x = self.layer2(x) # 16x16
174 | f2 = x
175 | x = self.layer3(x) # 8x8
176 | f3 = x
177 |
178 | x = self.avgpool(x)
179 | x = x.view(x.size(0), -1)
180 | f4 = x
181 | x = self.fc(x)
182 |
183 | if is_feat:
184 | return [f0, f1, f2, f3, f4], x
185 | else:
186 | return x
187 |
188 | def resnet8(**kwargs):
189 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs)
190 |
191 | def resnet14(**kwargs):
192 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs)
193 |
194 | def resnet20(**kwargs):
195 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs)
196 |
197 | def resnet32(**kwargs):
198 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs)
199 |
200 | # wrn_40_1 (We use the wrn notation to be consistent with the previous work)
201 | def resnet38(**kwargs):
202 | return ResNet(38, [16, 16, 32, 64], 'basicblock', **kwargs)
203 |
204 | def resnet44(**kwargs):
205 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs)
206 |
207 | def resnet56(**kwargs):
208 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs)
209 |
210 | def resnet110(**kwargs):
211 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs)
212 |
213 | def resnet116(**kwargs):
214 | return ResNet(116, [16, 16, 32, 64], 'basicblock', **kwargs)
215 |
216 | def resnet200(**kwargs):
217 | return ResNet(200, [16, 16, 32, 64], 'basicblock', **kwargs)
218 |
219 | # wrn_16_2 (We use the wrn notation to be consistent with the previous work)
220 | def resnet14x2(**kwargs):
221 | return ResNet(14, [16, 32, 64, 128], 'basicblock', **kwargs)
222 |
223 | # wrn_16_4 (We use the wrn notation to be consistent with the previous work)
224 | def resnet14x4(**kwargs):
225 | return ResNet(14, [32, 64, 128, 256], 'basicblock', **kwargs)
226 |
227 | # wrn_40_2 (we use the wrn notation for consistency with previous work)
228 | def resnet38x2(**kwargs):
229 | return ResNet(38, [16, 32, 64, 128], 'basicblock', **kwargs)
230 |
231 | def resnet110x2(**kwargs):
232 | return ResNet(110, [16, 32, 64, 128], 'basicblock', **kwargs)
233 |
234 | def resnet8x4(**kwargs):
235 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs)
236 |
237 | def resnet20x4(**kwargs):
238 | return ResNet(20, [32, 64, 128, 256], 'basicblock', **kwargs)
239 |
240 | def resnet26x4(**kwargs):
241 | return ResNet(26, [32, 64, 128, 256], 'basicblock', **kwargs)
242 |
243 | def resnet32x4(**kwargs):
244 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs)
245 |
246 | # wrn_40_4 (we use the wrn notation for consistency with previous work)
247 | def resnet38x4(**kwargs):
248 | return ResNet(38, [32, 64, 128, 256], 'basicblock', **kwargs)
249 |
250 | def resnet44x4(**kwargs):
251 | return ResNet(44, [32, 64, 128, 256], 'basicblock', **kwargs)
252 |
253 | def resnet56x4(**kwargs):
254 | return ResNet(56, [32, 64, 128, 256], 'basicblock', **kwargs)
255 |
256 | def resnet110x4(**kwargs):
257 | return ResNet(110, [32, 64, 128, 256], 'basicblock', **kwargs)
258 |
259 | if __name__ == '__main__':
260 | import torch
261 |
262 | x = torch.randn(2, 3, 32, 32)
263 | net = resnet32(num_classes=100)
264 | feats, logit = net(x, is_feat=True)
265 |
266 | for f in feats:
267 | print(f.shape, f.min().item())
268 | print(logit.shape)
269 |
270 | #for i, m in enumerate(net.get_feat_modules()):
271 | # print(i, m)
272 |
273 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
274 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
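275 |
276 | # a minimal extra sketch (illustrative): probing the default teacher/student
277 | # pair used in scripts/run_distill.sh, resnet32x4 -> resnet8x4
278 | net_t = resnet32x4(num_classes=100)
279 | net_s = resnet8x4(num_classes=100)
280 | feat_t, _ = net_t(x, is_feat=True)
281 | feat_s, _ = net_s(x, is_feat=True)
282 | print([f.shape[1] for f in feat_t], [f.shape[1] for f in feat_s])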
--------------------------------------------------------------------------------
/models/resnet_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 | # Only two changes:
3 | # 1. ResNet.forward() is modified to return inner feature maps.
4 | # 2. utils.py is merged into this file to import load_state_dict_from_url.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py
10 | try:
11 | from torch.hub import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 |
16 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
17 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
18 | 'wide_resnet50_2', 'wide_resnet101_2', 'wide_resnet34_4']
19 |
20 |
21 | model_urls = {
22 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
23 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
24 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
25 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
26 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
27 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
28 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
29 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
30 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
31 | }
32 |
33 |
34 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
35 | """3x3 convolution with padding"""
36 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
37 | padding=dilation, groups=groups, bias=False, dilation=dilation)
38 |
39 |
40 | def conv1x1(in_planes, out_planes, stride=1):
41 | """1x1 convolution"""
42 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
43 |
44 |
45 | class BasicBlock(nn.Module):
46 | expansion = 1
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
49 | base_width=64, dilation=1, norm_layer=None):
50 | super(BasicBlock, self).__init__()
51 | if norm_layer is None:
52 | norm_layer = nn.BatchNorm2d
53 | # if groups != 1 or base_width != 64:
54 | # raise ValueError('BasicBlock only supports groups=1 and base_width=64')
55 | if dilation > 1:
56 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
57 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
58 | self.conv1 = conv3x3(inplanes, planes, stride)
59 | self.bn1 = norm_layer(planes)
60 | self.relu = nn.ReLU(inplace=True)
61 | self.conv2 = conv3x3(planes, planes)
62 | self.bn2 = norm_layer(planes)
63 | self.downsample = downsample
64 | self.stride = stride
65 |
66 | def forward(self, x):
67 | identity = x
68 |
69 | out = self.conv1(x)
70 | out = self.bn1(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv2(out)
74 | out = self.bn2(out)
75 |
76 | if self.downsample is not None:
77 | identity = self.downsample(x)
78 |
79 | out += identity
80 | out = self.relu(out)
81 |
82 | return out
83 |
84 |
85 | class Bottleneck(nn.Module):
86 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
87 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
88 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
89 | # This variant is also known as ResNet V1.5 and improves accuracy according to
90 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
91 |
92 | expansion = 4
93 |
94 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
95 | base_width=64, dilation=1, norm_layer=None):
96 | super(Bottleneck, self).__init__()
97 | if norm_layer is None:
98 | norm_layer = nn.BatchNorm2d
99 | width = int(planes * (base_width / 64.)) * groups
100 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
101 | self.conv1 = conv1x1(inplanes, width)
102 | self.bn1 = norm_layer(width)
103 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
104 | self.bn2 = norm_layer(width)
105 | self.conv3 = conv1x1(width, planes * self.expansion)
106 | self.bn3 = norm_layer(planes * self.expansion)
107 | self.relu = nn.ReLU(inplace=True)
108 | self.downsample = downsample
109 | self.stride = stride
110 |
111 | def forward(self, x):
112 | identity = x
113 |
114 | out = self.conv1(x)
115 | out = self.bn1(out)
116 | out = self.relu(out)
117 |
118 | out = self.conv2(out)
119 | out = self.bn2(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv3(out)
123 | out = self.bn3(out)
124 |
125 | if self.downsample is not None:
126 | identity = self.downsample(x)
127 |
128 | out += identity
129 | out = self.relu(out)
130 |
131 | return out
132 |
133 | class ResNet(nn.Module):
134 |
135 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
136 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
137 | norm_layer=None):
138 | super(ResNet, self).__init__()
139 | if norm_layer is None:
140 | norm_layer = nn.BatchNorm2d
141 | self._norm_layer = norm_layer
142 |
143 | self.inplanes = 64
144 | self.dilation = 1
145 | self.multiplier = 1
146 | if replace_stride_with_dilation is None:
147 | # each element in the tuple indicates if we should replace
148 | # the 2x2 stride with a dilated convolution instead
149 | replace_stride_with_dilation = [False, False, False]
150 | if len(replace_stride_with_dilation) != 3:
151 | raise ValueError("replace_stride_with_dilation should be None "
152 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
153 | self.groups = groups
154 | self.base_width = width_per_group
155 | if self.base_width != 64 and (block is BasicBlock):
156 | self.multiplier = self.base_width // 64
157 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
158 | bias=False)
159 | self.bn1 = norm_layer(self.inplanes)
160 | self.relu = nn.ReLU(inplace=True)
161 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
162 | self.layer1 = self._make_layer(block, int(64 * self.multiplier), layers[0])
163 | self.layer2 = self._make_layer(block, int(128 * self.multiplier), layers[1], stride=2,
164 | dilate=replace_stride_with_dilation[0])
165 | self.layer3 = self._make_layer(block, int(256 * self.multiplier), layers[2], stride=2,
166 | dilate=replace_stride_with_dilation[1])
167 | self.layer4 = self._make_layer(block, int(512 * self.multiplier), layers[3], stride=2,
168 | dilate=replace_stride_with_dilation[2])
169 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
170 | self.fc = nn.Linear(int(512 * self.multiplier) * block.expansion, num_classes)
171 |
172 | for m in self.modules():
173 | if isinstance(m, nn.Conv2d):
174 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
175 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
176 | nn.init.constant_(m.weight, 1)
177 | nn.init.constant_(m.bias, 0)
178 |
179 | # Zero-initialize the last BN in each residual branch,
180 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
181 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
182 | if zero_init_residual:
183 | for m in self.modules():
184 | if isinstance(m, Bottleneck):
185 | nn.init.constant_(m.bn3.weight, 0)
186 | elif isinstance(m, BasicBlock):
187 | nn.init.constant_(m.bn2.weight, 0)
188 |
189 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
190 | norm_layer = self._norm_layer
191 | downsample = None
192 | previous_dilation = self.dilation
193 | if dilate:
194 | self.dilation *= stride
195 | stride = 1
196 |
197 | if stride != 1 or self.inplanes != planes * block.expansion:
198 | downsample = nn.Sequential(
199 | conv1x1(self.inplanes, planes * block.expansion, stride),
200 | norm_layer(planes * block.expansion),
201 | )
202 |
203 | layers = []
204 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
205 | self.base_width, previous_dilation, norm_layer))
206 | self.inplanes = planes * block.expansion
207 | for _ in range(1, blocks):
208 | layers.append(block(self.inplanes, planes, groups=self.groups,
209 | base_width=self.base_width, dilation=self.dilation,
210 | norm_layer=norm_layer))
211 |
212 | return nn.Sequential(*layers)
213 |
214 | def get_feat_modules(self):
215 | feat_m = nn.ModuleList([])
216 | feat_m.append(self.conv1)
217 | feat_m.append(self.bn1)
218 | feat_m.append(self.relu)
219 | feat_m.append(self.maxpool)
220 | feat_m.append(self.layer1)
221 | feat_m.append(self.layer2)
222 | feat_m.append(self.layer3)
223 | feat_m.append(self.layer4)
224 | feat_m.append(self.fc)
225 | return feat_m
226 |
227 | def forward(self, x, is_feat=False):
228 | x = self.conv1(x)
229 | x = self.bn1(x)
230 | x = self.relu(x)
231 | x = self.maxpool(x)
232 | f0 = x
233 |
234 | x = self.layer1(x)
235 | f1 = x
236 | x = self.layer2(x)
237 | f2 = x
238 | x = self.layer3(x)
239 | f3 = x
240 | x = self.layer4(x)
241 | f4 = x
242 |
243 | x = self.avgpool(x)
244 | x = torch.flatten(x, 1)
245 | f5 = x
246 | x = self.fc(x)
247 | if is_feat:
248 | return [f0, f1, f2, f3, f4, f5], x
249 | else:
250 | return x
251 |
252 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
253 | model = ResNet(block, layers, **kwargs)
254 | if pretrained:
255 | state_dict = load_state_dict_from_url(model_urls[arch],
256 | progress=progress)
257 | model.load_state_dict(state_dict)
258 | return model
259 |
260 |
261 | def resnet18(pretrained=False, progress=True, **kwargs):
262 | r"""ResNet-18 model from
263 | `"Deep Residual Learning for Image Recognition" `_
264 |
265 | Args:
266 | pretrained (bool): If True, returns a model pre-trained on ImageNet
267 | progress (bool): If True, displays a progress bar of the download to stderr
268 | """
269 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
270 | **kwargs)
271 |
272 | def wide_resnet10_2(pretrained=False, progress=True, **kwargs):
273 |
274 | kwargs['width_per_group'] = 64 * 2
275 | return _resnet('wide_resnet10_2', BasicBlock, [1, 1, 1, 1], pretrained, progress,
276 | **kwargs)
277 |
278 | def wide_resnet18_2(pretrained=False, progress=True, **kwargs):
279 |
280 | kwargs['width_per_group'] = 64 * 2
281 | return _resnet('wide_resnet18_2', BasicBlock, [2, 2, 2, 2], pretrained, progress,
282 | **kwargs)
283 |
284 | def wide_resnet26_2(pretrained=False, progress=True, **kwargs):
285 |
286 | kwargs['width_per_group'] = 64 * 2
287 | return _resnet('wide_resnet26_2', BasicBlock, [3, 3, 3, 3], pretrained, progress,
288 | **kwargs)
289 |
290 | def resnet34(pretrained=False, progress=True, **kwargs):
291 | r"""ResNet-34 model from
292 | `"Deep Residual Learning for Image Recognition" `_
293 |
294 | Args:
295 | pretrained (bool): If True, returns a model pre-trained on ImageNet
296 | progress (bool): If True, displays a progress bar of the download to stderr
297 | """
298 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
299 | **kwargs)
300 |
301 | def resnet50(pretrained=False, progress=True, **kwargs):
302 | r"""ResNet-50 model from
303 | `"Deep Residual Learning for Image Recognition" `_
304 |
305 | Args:
306 | pretrained (bool): If True, returns a model pre-trained on ImageNet
307 | progress (bool): If True, displays a progress bar of the download to stderr
308 | """
309 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
310 | **kwargs)
311 |
312 | def resnet101(pretrained=False, progress=True, **kwargs):
313 | r"""ResNet-101 model from
314 | `"Deep Residual Learning for Image Recognition" `_
315 |
316 | Args:
317 | pretrained (bool): If True, returns a model pre-trained on ImageNet
318 | progress (bool): If True, displays a progress bar of the download to stderr
319 | """
320 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
321 | **kwargs)
322 |
323 | def resnet152(pretrained=False, progress=True, **kwargs):
324 | r"""ResNet-152 model from
325 | `"Deep Residual Learning for Image Recognition" `_
326 |
327 | Args:
328 | pretrained (bool): If True, returns a model pre-trained on ImageNet
329 | progress (bool): If True, displays a progress bar of the download to stderr
330 | """
331 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
332 | **kwargs)
333 |
334 | def wide_resnet34_2(pretrained=False, progress=True, **kwargs):
335 |
336 | kwargs['width_per_group'] = 64 * 2
337 | return _resnet('wide_resnet34_2', BasicBlock, [3, 4, 6, 3], pretrained, progress,
338 | **kwargs)
339 |
340 | def wide_resnet34_4(pretrained=False, progress=True, **kwargs):
341 |
342 | kwargs['width_per_group'] = 64 * 4
343 | return _resnet('wide_resnet34_4', BasicBlock, [3, 4, 6, 3], pretrained, progress,
344 | **kwargs)
345 |
346 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
347 | r"""ResNeXt-50 32x4d model from
348 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
349 |
350 | Args:
351 | pretrained (bool): If True, returns a model pre-trained on ImageNet
352 | progress (bool): If True, displays a progress bar of the download to stderr
353 | """
354 | kwargs['groups'] = 32
355 | kwargs['width_per_group'] = 4
356 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
357 | pretrained, progress, **kwargs)
358 |
359 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
360 | r"""ResNeXt-101 32x8d model from
361 | `"Aggregated Residual Transformation for Deep Neural Networks" `_
362 |
363 | Args:
364 | pretrained (bool): If True, returns a model pre-trained on ImageNet
365 | progress (bool): If True, displays a progress bar of the download to stderr
366 | """
367 | kwargs['groups'] = 32
368 | kwargs['width_per_group'] = 8
369 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
370 | pretrained, progress, **kwargs)
371 |
372 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
373 | r"""Wide ResNet-50-2 model from
374 | `"Wide Residual Networks" `_
375 |
376 | The model is the same as ResNet except for the bottleneck number of channels
377 | which is twice larger in every block. The number of channels in outer 1x1
378 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
379 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
380 |
381 | Args:
382 | pretrained (bool): If True, returns a model pre-trained on ImageNet
383 | progress (bool): If True, displays a progress bar of the download to stderr
384 | """
385 | kwargs['width_per_group'] = 64 * 2
386 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
387 | pretrained, progress, **kwargs)
388 |
389 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
390 | r"""Wide ResNet-101-2 model from
391 | `"Wide Residual Networks" `_
392 |
393 | The model is the same as ResNet except for the bottleneck number of channels
394 | which is twice larger in every block. The number of channels in outer 1x1
395 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
396 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
397 |
398 | Args:
399 | pretrained (bool): If True, returns a model pre-trained on ImageNet
400 | progress (bool): If True, displays a progress bar of the download to stderr
401 | """
402 | kwargs['width_per_group'] = 64 * 2
403 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
404 | pretrained, progress, **kwargs)
405 |
406 | if __name__ == '__main__':
407 |
408 | x = torch.randn(64, 3, 224, 224)
409 |
410 | net = wide_resnet10_2()
411 |
412 | feats, logit = net(x, is_feat=True)
413 | for f in feats:
414 | print(f.shape, f.min().item())
415 | print(logit.shape)
416 |
417 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
418 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
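419 |
420 | # pretrained ImageNet teachers can be fetched through the model_urls above,
421 | # e.g. (a sketch; downloads the checkpoint on first call):
422 | # net_t = resnet50(pretrained=True)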
--------------------------------------------------------------------------------
/models/shuffleNetv2_imagenet.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py
2 | # Only two changes:
3 | # 1. ShuffleNetV2.forward() is modified to return inner feature maps.
4 | # 2. utils.py is merged into this file to import load_state_dict_from_url.
5 |
6 | import torch
7 | import torch.nn as nn
8 |
9 | # https://github.com/pytorch/vision/blob/master/torchvision/models/utils.py
10 | try:
11 | from torch.hub import load_state_dict_from_url
12 | except ImportError:
13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
14 |
15 |
16 |
17 | __all__ = [
18 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
19 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
20 | ]
21 |
22 | model_urls = {
23 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
24 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
25 | 'shufflenetv2_x1.5': None,
26 | 'shufflenetv2_x2.0': None,
27 | }
28 |
29 |
30 | def channel_shuffle(x, groups):
31 | # type: (torch.Tensor, int) -> torch.Tensor
32 | batchsize, num_channels, height, width = x.data.size()
33 | channels_per_group = num_channels // groups
34 |
35 | # reshape
36 | x = x.view(batchsize, groups,
37 | channels_per_group, height, width)
38 |
39 | x = torch.transpose(x, 1, 2).contiguous()
40 |
41 | # flatten
42 | x = x.view(batchsize, -1, height, width)
43 |
44 | return x
45 |
46 |
47 | class InvertedResidual(nn.Module):
48 | def __init__(self, inp, oup, stride):
49 | super(InvertedResidual, self).__init__()
50 |
51 | if not (1 <= stride <= 3):
52 | raise ValueError('illegal stride value')
53 | self.stride = stride
54 |
55 | branch_features = oup // 2
56 | assert (self.stride != 1) or (inp == branch_features << 1)
57 |
58 | if self.stride > 1:
59 | self.branch1 = nn.Sequential(
60 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
61 | nn.BatchNorm2d(inp),
62 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
63 | nn.BatchNorm2d(branch_features),
64 | nn.ReLU(inplace=True),
65 | )
66 | else:
67 | self.branch1 = nn.Sequential()
68 |
69 | self.branch2 = nn.Sequential(
70 | nn.Conv2d(inp if (self.stride > 1) else branch_features,
71 | branch_features, kernel_size=1, stride=1, padding=0, bias=False),
72 | nn.BatchNorm2d(branch_features),
73 | nn.ReLU(inplace=True),
74 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
75 | nn.BatchNorm2d(branch_features),
76 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
77 | nn.BatchNorm2d(branch_features),
78 | nn.ReLU(inplace=True),
79 | )
80 |
81 | @staticmethod
82 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
83 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
84 |
85 | def forward(self, x):
86 | if self.stride == 1:
87 | x1, x2 = x.chunk(2, dim=1)
88 | out = torch.cat((x1, self.branch2(x2)), dim=1)
89 | else:
90 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
91 |
92 | out = channel_shuffle(out, 2)
93 |
94 | return out
95 |
96 |
97 | class ShuffleNetV2(nn.Module):
98 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual):
99 | super(ShuffleNetV2, self).__init__()
100 |
101 | if len(stages_repeats) != 3:
102 | raise ValueError('expected stages_repeats as list of 3 positive ints')
103 | if len(stages_out_channels) != 5:
104 | raise ValueError('expected stages_out_channels as list of 5 positive ints')
105 | self._stage_out_channels = stages_out_channels
106 |
107 | input_channels = 3
108 | output_channels = self._stage_out_channels[0]
109 | self.conv1 = nn.Sequential(
110 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
111 | nn.BatchNorm2d(output_channels),
112 | nn.ReLU(inplace=True),
113 | )
114 | input_channels = output_channels
115 |
116 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
117 |
118 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
119 | for name, repeats, output_channels in zip(
120 | stage_names, stages_repeats, self._stage_out_channels[1:]):
121 | seq = [inverted_residual(input_channels, output_channels, 2)]
122 | for i in range(repeats - 1):
123 | seq.append(inverted_residual(output_channels, output_channels, 1))
124 | setattr(self, name, nn.Sequential(*seq))
125 | input_channels = output_channels
126 |
127 | output_channels = self._stage_out_channels[-1]
128 | self.conv5 = nn.Sequential( # required: forward() calls it, and self.fc expects its output width
129 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
130 | nn.BatchNorm2d(output_channels),
131 | nn.ReLU(inplace=True),
132 | )
133 |
134 | self.fc = nn.Linear(output_channels, num_classes)
135 |
136 | def get_feat_modules(self):
137 | feat_m = nn.ModuleList([])
138 | feat_m.append(self.conv1)
139 | feat_m.append(self.maxpool)
140 | # stages are named stage2..stage4; there is no self.stage1
141 | feat_m.append(self.stage2)
142 | feat_m.append(self.stage3)
143 | feat_m.append(self.stage4)
144 | feat_m.append(self.conv5)
145 | feat_m.append(self.fc)
146 | return feat_m
147 |
148 | def forward(self, x, is_feat=False):
149 | hidden_layers = []
150 | x = self.conv1(x)
151 | x = self.maxpool(x)
152 | hidden_layers.append(x)
153 | x = self.stage2(x)
154 | hidden_layers.append(x)
155 | x = self.stage3(x)
156 | hidden_layers.append(x)
157 | x = self.stage4(x)
158 | x = self.conv5(x)
159 | hidden_layers.append(x)
160 | x = x.mean([2, 3]) # globalpool
161 | hidden_layers.append(x)
162 | x = self.fc(x)
163 | if not is_feat:
164 | return x
165 | else:
166 | return hidden_layers, x
167 |
168 |
169 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
170 | model = ShuffleNetV2(*args, **kwargs)
171 |
172 | if pretrained:
173 | model_url = model_urls[arch]
174 | if model_url is None:
175 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
176 | else:
177 | state_dict = load_state_dict_from_url(model_url, progress=progress)
178 | model.load_state_dict(state_dict)
179 |
180 | return model
181 |
182 |
183 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
184 | """
185 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in
186 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
187 | <https://arxiv.org/abs/1807.11164>`_.
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | progress (bool): If True, displays a progress bar of the download to stderr
192 | """
193 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
194 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
195 |
196 |
197 |
198 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
199 | """
200 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in
201 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
202 | <https://arxiv.org/abs/1807.11164>`_.
203 |
204 | Args:
205 | pretrained (bool): If True, returns a model pre-trained on ImageNet
206 | progress (bool): If True, displays a progress bar of the download to stderr
207 | """
208 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
209 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
210 |
211 |
212 |
213 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
214 | """
215 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in
216 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
217 | <https://arxiv.org/abs/1807.11164>`_.
218 |
219 | Args:
220 | pretrained (bool): If True, returns a model pre-trained on ImageNet
221 | progress (bool): If True, displays a progress bar of the download to stderr
222 | """
223 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
224 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
225 |
226 |
227 |
228 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
229 | """
230 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in
231 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
232 | <https://arxiv.org/abs/1807.11164>`_.
233 |
234 | Args:
235 | pretrained (bool): If True, returns a model pre-trained on ImageNet
236 | progress (bool): If True, displays a progress bar of the download to stderr
237 | """
238 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
239 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
240 |
241 | if __name__ == '__main__':
242 | #x = torch.randn(2, 3, 32, 32)
243 | x = torch.randn(2, 3, 224, 224)
244 |
245 | #net = mobile_half(100)
246 | net = shufflenet_v2_x1_0()
247 |
248 | feats, logit = net(x, is_feat=True)
249 | for f in feats:
250 | print(f.shape, f.min().item())
251 | print(logit.shape)
252 |
253 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
254 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
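255 |
256 | # note: pretrained weights exist only for x0.5 and x1.0 (see model_urls);
257 | # e.g. net = shufflenet_v2_x1_0(pretrained=True) downloads from torchvision,
258 | # while shufflenet_v2_x2_0(pretrained=True) raises NotImplementedError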
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class ConvReg(nn.Module):
8 | """Convolutional regression for FitNet (feature-map layer)"""
9 | def __init__(self, s_shape, t_shape):
10 | super(ConvReg, self).__init__()
11 | _, s_C, s_H, s_W = s_shape
12 | _, t_C, t_H, t_W = t_shape
13 | self.s_H = s_H
14 | self.t_H = t_H
15 | if s_H == 2 * t_H:
16 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
17 | elif s_H * 2 == t_H:
18 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
19 | elif s_H >= t_H:
20 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
21 | else:
22 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, padding=1, stride=1)
23 | self.bn = nn.BatchNorm2d(t_C)
24 | self.relu = nn.ReLU(inplace=True)
25 |
26 | def forward(self, x, t):
27 | x = self.conv(x)
28 | if self.s_H == 2 * self.t_H or self.s_H * 2 == self.t_H or self.s_H >= self.t_H:
29 | return self.relu(self.bn(x)), t
30 | else:
31 | return self.relu(self.bn(x)), F.adaptive_avg_pool2d(t, (self.s_H, self.s_H))
32 |
33 | class SelfA(nn.Module):
34 | """Cross-layer Self Attention"""
35 | def __init__(self, feat_dim, s_n, t_n, soft, factor=4):
36 | super(SelfA, self).__init__()
37 |
38 | self.soft = soft
39 | self.s_len = len(s_n)
40 | self.t_len = len(t_n)
41 | self.feat_dim = feat_dim
42 |
43 | # query and key mapping
44 | for i in range(self.s_len):
45 | setattr(self, 'query_'+str(i), MLPEmbed(feat_dim, feat_dim//factor))
46 | for i in range(self.t_len):
47 | setattr(self, 'key_'+str(i), MLPEmbed(feat_dim, feat_dim//factor))
48 |
49 | for i in range(self.s_len):
50 | for j in range(self.t_len):
51 | setattr(self, 'regressor'+str(i)+str(j), Proj(s_n[i], t_n[j]))
52 |
53 | def forward(self, feat_s, feat_t):
54 |
55 | sim_s = list(range(self.s_len))
56 | sim_t = list(range(self.t_len))
57 | bsz = self.feat_dim # feat_dim is the batch size here (SelfA is built with opt.batch_size), so the similarity matrices are bsz x bsz
58 |
59 | # similarity matrix
60 | for i in range(self.s_len):
61 | sim_temp = feat_s[i].reshape(bsz, -1)
62 | sim_s[i] = torch.matmul(sim_temp, sim_temp.t())
63 | for i in range(self.t_len):
64 | sim_temp = feat_t[i].reshape(bsz, -1)
65 | sim_t[i] = torch.matmul(sim_temp, sim_temp.t())
66 |
67 | # calculate student query
68 | proj_query = self.query_0(sim_s[0])
69 | proj_query = proj_query[:, None, :]
70 | for i in range(1, self.s_len):
71 | temp_proj_query = getattr(self, 'query_'+str(i))(sim_s[i])
72 | proj_query = torch.cat([proj_query, temp_proj_query[:, None, :]], 1)
73 |
74 | # calculate teacher key
75 | proj_key = self.key_0(sim_t[0])
76 | proj_key = proj_key[:, :, None]
77 | for i in range(1, self.t_len):
78 | temp_proj_key = getattr(self, 'key_'+str(i))(sim_t[i])
79 | proj_key = torch.cat([proj_key, temp_proj_key[:, :, None]], 2)
80 |
81 | # attention weight: batch_size x num_student_features x num_teacher_features
82 | energy = torch.bmm(proj_query, proj_key)/self.soft
83 | attention = F.softmax(energy, dim = -1)
84 |
85 | # feature dimension alignment
86 | proj_value_stu = []
87 | value_tea = []
88 | for i in range(self.s_len):
89 | proj_value_stu.append([])
90 | value_tea.append([])
91 | for j in range(self.t_len):
92 | s_H, t_H = feat_s[i].shape[2], feat_t[j].shape[2]
93 | if s_H > t_H:
94 | source = F.adaptive_avg_pool2d(feat_s[i], (t_H, t_H))
95 | target = feat_t[j]
96 | elif s_H <= t_H:
97 | source = feat_s[i]
98 | target = F.adaptive_avg_pool2d(feat_t[j], (s_H, s_H))
99 |
100 | proj_value_stu[i].append(getattr(self, 'regressor'+str(i)+str(j))(source))
101 | value_tea[i].append(target)
102 |
103 | return proj_value_stu, value_tea, attention
104 |
105 | class Proj(nn.Module):
106 | """feature dimension alignment by 1x1, 3x3, 1x1 convolutions"""
107 | def __init__(self, num_input_channels=1024, num_target_channels=128):
108 | super(Proj, self).__init__()
109 | self.num_mid_channel = 2 * num_target_channels
110 |
111 | def conv1x1(in_channels, out_channels, stride=1):
112 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False)
113 | def conv3x3(in_channels, out_channels, stride=1):
114 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
115 |
116 | self.regressor = nn.Sequential(
117 | conv1x1(num_input_channels, self.num_mid_channel),
118 | nn.BatchNorm2d(self.num_mid_channel),
119 | nn.ReLU(inplace=True),
120 | conv3x3(self.num_mid_channel, self.num_mid_channel),
121 | nn.BatchNorm2d(self.num_mid_channel),
122 | nn.ReLU(inplace=True),
123 | conv1x1(self.num_mid_channel, num_target_channels),
124 | )
125 |
126 | def forward(self, x):
127 | x = self.regressor(x)
128 | return x
129 |
130 | class MLPEmbed(nn.Module):
131 | """non-linear mapping for attention calculation"""
132 | def __init__(self, dim_in=1024, dim_out=128):
133 | super(MLPEmbed, self).__init__()
134 | self.linear1 = nn.Linear(dim_in, 2 * dim_out)
135 | self.relu = nn.ReLU(inplace=True)
136 | self.linear2 = nn.Linear(2 * dim_out, dim_out)
137 | self.l2norm = Normalize(2)
138 | self.regressor = nn.Sequential( # note: defined but not used in forward()
139 | nn.Linear(dim_in, 2 * dim_out),
140 | self.l2norm,
141 | nn.ReLU(inplace=True),
142 | nn.Linear(2 * dim_out, dim_out),
143 | self.l2norm,
144 | )
145 |
146 | def forward(self, x):
147 | x = x.view(x.shape[0], -1)
148 | x = self.relu(self.linear1(x))
149 | x = self.l2norm(self.linear2(x))
150 |
151 | return x
152 |
153 | class Normalize(nn.Module):
154 | """normalization layer"""
155 | def __init__(self, power=2):
156 | super(Normalize, self).__init__()
157 | self.power = power
158 |
159 | def forward(self, x):
160 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
161 | out = x.div(norm)
162 | return out
163 |
164 | class SRRL(nn.Module):
165 | """ICLR-2021: Knowledge Distillation via Softmax Regression Representation Learning"""
166 | def __init__(self, *, s_n, t_n):
167 | super(SRRL, self).__init__()
168 |
169 | def conv1x1(in_channels, out_channels, stride=1):
170 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False)
171 |
172 | setattr(self, 'transfer', nn.Sequential(
173 | conv1x1(s_n, t_n),
174 | nn.BatchNorm2d(t_n),
175 | nn.ReLU(inplace=True),
176 | ))
177 |
178 | def forward(self, feat_s, cls_t):
179 |
180 | feat_s = feat_s.unsqueeze(-1).unsqueeze(-1)
181 | temp_feat = self.transfer(feat_s)
182 | trans_feat_s = temp_feat.view(temp_feat.size(0), -1)
183 |
184 | pred_feat_s = cls_t(trans_feat_s)
185 |
186 | return trans_feat_s, pred_feat_s
187 |
188 | class SimKD(nn.Module):
189 | """CVPR-2022: Knowledge Distillation with the Reused Teacher Classifier"""
190 | def __init__(self, *, s_n, t_n, factor=2):
191 | super(SimKD, self).__init__()
192 |
193 | self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
194 |
195 | def conv1x1(in_channels, out_channels, stride=1):
196 | return nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride, bias=False)
197 | def conv3x3(in_channels, out_channels, stride=1, groups=1):
198 | return nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False, groups=groups)
199 |
200 | # A bottleneck design to reduce extra parameters
201 | setattr(self, 'transfer', nn.Sequential(
202 | conv1x1(s_n, t_n//factor),
203 | nn.BatchNorm2d(t_n//factor),
204 | nn.ReLU(inplace=True),
205 | conv3x3(t_n//factor, t_n//factor),
206 | # depthwise convolution
207 | #conv3x3(t_n//factor, t_n//factor, groups=t_n//factor),
208 | nn.BatchNorm2d(t_n//factor),
209 | nn.ReLU(inplace=True),
210 | conv1x1(t_n//factor, t_n),
211 | nn.BatchNorm2d(t_n),
212 | nn.ReLU(inplace=True),
213 | ))
214 |
215 | def forward(self, feat_s, feat_t, cls_t):
216 |
217 | # Spatial Dimension Alignment
218 | s_H, t_H = feat_s.shape[2], feat_t.shape[2]
219 | if s_H > t_H:
220 | source = F.adaptive_avg_pool2d(feat_s, (t_H, t_H))
221 | target = feat_t
222 | else:
223 | source = feat_s
224 | target = F.adaptive_avg_pool2d(feat_t, (s_H, s_H))
225 |
226 | trans_feat_t = target
227 |
228 | # Channel Alignment
229 | trans_feat_s = getattr(self, 'transfer')(source)
230 |
231 | # Prediction via Teacher Classifier
232 | temp_feat = self.avg_pool(trans_feat_s)
233 | temp_feat = temp_feat.view(temp_feat.size(0), -1)
234 | pred_feat_s = cls_t(temp_feat)
235 |
236 | return trans_feat_s, trans_feat_t, pred_feat_s
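237 |
238 | if __name__ == '__main__':
239 | # a minimal sketch (illustrative shapes, not tuned values) of the SimKD
240 | # forward pass: align the student feature map to the teacher's channel
241 | # width, then reuse the teacher classifier for the prediction
242 | feat_s = torch.randn(2, 64, 8, 8) # student penultimate feature map
243 | feat_t = torch.randn(2, 256, 8, 8) # teacher penultimate feature map
244 | cls_t = nn.Linear(256, 100) # stands in for a trained teacher fc layer
245 | simkd = SimKD(s_n=64, t_n=256, factor=2)
246 | trans_s, trans_t, pred_s = simkd(feat_s, feat_t, cls_t)
247 | # training would apply an MSE loss between trans_s and trans_t
248 | print(trans_s.shape, trans_t.shape, pred_s.shape)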
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''
2 | The three FC layers of ImageNet-style VGG are replaced with a single one,
3 | so the total layer count is reduced by two on CIFAR-100.
4 | For example, VGG-8 actually contains 6 layers.
5 |
6 | VGG for CIFAR10. FC layers are removed.
7 | (c) YANG, Wei
8 | '''
9 | import math
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | __all__ = [
14 | 'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
15 | 'vgg16', 'vgg16_bn', 'vgg19_bn', 'vgg19',
16 | ]
17 |
18 | class VGG(nn.Module):
19 |
20 | def __init__(self, cfg, batch_norm=False, num_classes=1000):
21 | super(VGG, self).__init__()
22 | self.block0 = self._make_layers(cfg[0], batch_norm, 3)
23 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
24 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
25 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
26 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
27 |
28 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
29 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
30 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
31 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
32 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
33 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
34 | self.relu = nn.ReLU(inplace=True)
35 |
36 | self.classifier = nn.Linear(512, num_classes)
37 | self._initialize_weights()
38 |
39 | def get_feat_modules(self):
40 | feat_m = nn.ModuleList([])
41 | feat_m.append(self.block0)
42 | feat_m.append(self.pool0)
43 | feat_m.append(self.block1)
44 | feat_m.append(self.pool1)
45 | feat_m.append(self.block2)
46 | feat_m.append(self.pool2)
47 | feat_m.append(self.block3)
48 | feat_m.append(self.pool3)
49 | feat_m.append(self.block4)
50 | feat_m.append(self.pool4)
51 | feat_m.append(self.classifier)
52 | return feat_m
53 |
54 | def forward(self, x, is_feat=False):
55 | h = x.shape[2]
56 | x = F.relu(self.block0(x))
57 | f0 = x
58 | x = self.pool0(x)
59 | x = self.block1(x)
60 | x = self.relu(x)
61 | f1 = x
62 | x = self.pool1(x)
63 | x = self.block2(x)
64 | x = self.relu(x)
65 | f2 = x
66 | x = self.pool2(x)
67 | x = self.block3(x)
68 | x = self.relu(x)
69 | f3 = x
70 | if h == 64:
71 | x = self.pool3(x)
72 | x = self.block4(x)
73 | x = self.relu(x)
74 | f4 = x
75 | x = self.pool4(x)
76 | x = x.view(x.size(0), -1)
77 | f5 = x
78 | x = self.classifier(x)
79 |
80 | if is_feat:
81 | return [f0, f1, f2, f3, f4, f5], x
82 | else:
83 | return x
84 |
85 | @staticmethod
86 | def _make_layers(cfg, batch_norm=False, in_channels=3):
87 | layers = []
88 | for v in cfg:
89 | if v == 'M':
90 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
91 | else:
92 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
93 | if batch_norm:
94 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
95 | else:
96 | layers += [conv2d, nn.ReLU(inplace=True)]
97 | in_channels = v
98 | layers = layers[:-1]
99 | return nn.Sequential(*layers)
100 |
101 | def _initialize_weights(self):
102 | for m in self.modules():
103 | if isinstance(m, nn.Conv2d):
104 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
105 | m.weight.data.normal_(0, math.sqrt(2. / n))
106 | if m.bias is not None:
107 | m.bias.data.zero_()
108 | elif isinstance(m, nn.BatchNorm2d):
109 | m.weight.data.fill_(1)
110 | m.bias.data.zero_()
111 | elif isinstance(m, nn.Linear):
112 | n = m.weight.size(1)
113 | m.weight.data.normal_(0, 0.01)
114 | m.bias.data.zero_()
115 |
116 | cfg = {
117 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]],
118 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
119 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
120 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
121 | 'S': [[64], [128], [256], [512], [512]],
122 | }
123 |
124 | def vgg8(**kwargs):
125 | """VGG 8-layer model (configuration "S")"""
126 | model = VGG(cfg['S'], **kwargs)
127 | return model
128 |
129 |
130 | def vgg8_bn(**kwargs):
131 | """VGG 8-layer model (configuration "S")"""
132 | model = VGG(cfg['S'], batch_norm=True, **kwargs)
133 | return model
134 |
135 |
136 | def vgg11(**kwargs):
137 | """VGG 11-layer model (configuration "A")"""
138 | model = VGG(cfg['A'], **kwargs)
139 | return model
140 |
141 |
142 | def vgg11_bn(**kwargs):
143 | """VGG 11-layer model (configuration "A") with batch normalization"""
144 | model = VGG(cfg['A'], batch_norm=True, **kwargs)
145 | return model
146 |
147 |
148 | def vgg13(**kwargs):
149 | """VGG 13-layer model (configuration "B")"""
150 | model = VGG(cfg['B'], **kwargs)
151 | return model
152 |
153 |
154 | def vgg13_bn(**kwargs):
155 | """VGG 13-layer model (configuration "B") with batch normalization"""
156 | model = VGG(cfg['B'], batch_norm=True, **kwargs)
157 | return model
158 |
159 |
160 | def vgg16(**kwargs):
161 | """VGG 16-layer model (configuration "D")"""
162 | model = VGG(cfg['D'], **kwargs)
163 | return model
164 |
165 |
166 | def vgg16_bn(**kwargs):
167 | """VGG 16-layer model (configuration "D") with batch normalization"""
168 | model = VGG(cfg['D'], batch_norm=True, **kwargs)
169 | return model
170 |
171 |
172 | def vgg19(**kwargs):
173 | """VGG 19-layer model (configuration "E")"""
174 | model = VGG(cfg['E'], **kwargs)
175 | return model
176 |
177 |
178 | def vgg19_bn(**kwargs):
179 | """VGG 19-layer model (configuration 'E') with batch normalization"""
180 | model = VGG(cfg['E'], batch_norm=True, **kwargs)
181 | return model
182 |
183 |
184 | if __name__ == '__main__':
185 | import torch
186 |
187 | x = torch.randn(2, 3, 32, 32)
188 | net = vgg13_bn(num_classes=100)
189 | feats, logit = net(x, is_feat=True)
190 |
191 | for f in feats:
192 | print(f.shape, f.min().item())
193 | print(logit.shape)
194 |
195 | num_params_stu = (sum(p.numel() for p in net.parameters())/1000000.0)
196 | print('Total params_stu: {:.3f} M'.format(num_params_stu))
197 |
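198 | # sanity sketch for the note in the header docstring: vgg8 (config 'S') has
199 | # 5 conv layers plus 1 fc layer, i.e. 6 layers in total
200 | # print(sum(1 for m in vgg8().modules() if isinstance(m, nn.Conv2d))) # 5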
--------------------------------------------------------------------------------
/scripts/run_distill.sh:
--------------------------------------------------------------------------------
1 | # sample scripts for running various knowledge distillation approaches
2 | # we use resnet32x4 and resnet8x4 as an example
3 |
4 | # CIFAR
5 | # KD
6 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill kd --model_s resnet8x4 -c 1 -d 1 -b 0 --trial 0 --gpu_id 0
7 | # FitNet
8 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill hint --model_s resnet8x4 -c 1 -d 1 -b 100 --trial 0 --gpu_id 0
9 | # AT
10 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill attention --model_s resnet8x4 -c 1 -d 1 -b 1000 --trial 0 --gpu_id 0
11 | # SP
12 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill similarity --model_s resnet8x4 -c 1 -d 1 -b 3000 --trial 0 --gpu_id 0
13 | # VID
14 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill vid --model_s resnet8x4 -c 1 -d 1 -b 1 --trial 0 --gpu_id 0
15 | # CRD
16 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill crd --model_s resnet8x4 -c 1 -d 1 -b 0.8 --trial 0 --gpu_id 0
17 | # SemCKD
18 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill semckd --model_s resnet8x4 -c 1 -d 1 -b 400 --trial 0 --gpu_id 0
19 | # SRRL
20 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill srrl --model_s resnet8x4 -c 1 -d 1 -b 1 --trial 0 --gpu_id 0
21 | # SimKD
22 | python train_student.py --path_t ./save/teachers/models/resnet32x4_vanilla/resnet32x4_best.pth --distill simkd --model_s resnet8x4 -c 0 -d 0 -b 1 --trial 0 --gpu_id 0
23 |
24 | # ImageNet
25 | python train_student.py --path_t './save/teachers/models/ResNet50_vanilla/ResNet50_best.pth' --batch_size 256 --epochs 120 --dataset imagenet --model_s ResNet18 --distill simkd -c 0 -d 0 -b 1 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23344 --multiprocessing-distributed --dali gpu --trial 0
26 |
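27 | # other teacher/student pairs follow the same pattern; e.g. SimKD with a vgg13
28 | # teacher and a vgg8 student (assuming a vgg13 checkpoint trained via run_vanilla.sh):
29 | # python train_student.py --path_t ./save/teachers/models/vgg13_vanilla/vgg13_best.pth --distill simkd --model_s vgg8 -c 0 -d 0 -b 1 --trial 0 --gpu_id 0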
--------------------------------------------------------------------------------
/scripts/run_vanilla.sh:
--------------------------------------------------------------------------------
1 | # sample scripts for training vanilla teacher/student models
2 |
3 | # CIFAR
4 | python train_teacher.py --model resnet8x4 --trial 0 --gpu_id 0
5 |
6 | python train_teacher.py --model resnet32x4 --trial 0 --gpu_id 0
7 |
8 | python train_teacher.py --model resnet110 --trial 0 --gpu_id 0
9 |
10 | python train_teacher.py --model resnet116 --trial 0 --gpu_id 0
11 |
12 | python train_teacher.py --model resnet110x2 --trial 0 --gpu_id 0
13 |
14 | python train_teacher.py --model vgg8 --trial 0 --gpu_id 0
15 |
16 | python train_teacher.py --model vgg13 --trial 0 --gpu_id 0
17 |
18 | python train_teacher.py --model ShuffleV1 --trial 0 --gpu_id 0
19 |
20 | python train_teacher.py --model ShuffleV2 --trial 0 --gpu_id 0
21 |
22 | python train_teacher.py --model ShuffleV2_1_5 --trial 0 --gpu_id 0
23 |
24 | python train_teacher.py --model MobileNetV2 --trial 0 --gpu_id 0
25 |
26 | python train_teacher.py --model MobileNetV2_1_0 --trial 0 --gpu_id 0
27 |
28 | # WRN-40-1
29 | python train_teacher.py --model resnet38 --trial 0 --gpu_id 0
30 | # WRN-40-2
31 | python train_teacher.py --model resnet38x2 --trial 0 --gpu_id 0
32 | # WRN-16-2
33 | python train_teacher.py --model resnet14x2 --trial 0 --gpu_id 0
34 | # WRN-40-4
35 | python train_teacher.py --model resnet38x4 --trial 0 --gpu_id 0
36 | # WRN-16-4
37 | python train_teacher.py --model resnet14x4 --trial 0 --gpu_id 0
38 |
39 | # ImageNet
40 | python train_teacher.py --batch_size 256 --epochs 120 --dataset imagenet --model ResNet18 --learning_rate 0.1 --lr_decay_epochs 30,60,90 --weight_decay 1e-4 --num_workers 32 --gpu_id 0,1,2,3 --dist-url tcp://127.0.0.1:23333 --multiprocessing-distributed --dali gpu --trial 0
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/train_student.py:
--------------------------------------------------------------------------------
1 | """
2 | the general training framework
3 | """
4 |
5 | from __future__ import print_function
6 |
7 | import os
8 | import re
9 | import argparse
10 | import time
11 |
12 | import numpy
13 | import torch
14 | import torch.optim as optim
15 | import torch.multiprocessing as mp
16 | import torch.distributed as dist
17 | import torch.nn as nn
18 | import torch.backends.cudnn as cudnn
19 | import tensorboard_logger as tb_logger
20 |
21 | from models import model_dict
22 | from models.util import ConvReg, SelfA, SRRL, SimKD
23 |
24 | from dataset.cifar100 import get_cifar100_dataloaders, get_cifar100_dataloaders_sample
25 | from dataset.imagenet import get_imagenet_dataloader, get_dataloader_sample
26 | from dataset.imagenet_dali import get_dali_data_loader
27 |
28 | from helper.loops import train_distill as train, validate_vanilla, validate_distill
29 | from helper.util import save_dict_to_json, reduce_tensor, adjust_learning_rate
30 |
31 | from crd.criterion import CRDLoss
32 | from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, VIDLoss, SemCKDLoss
33 |
34 | split_symbol = '~' if os.name == 'nt' else ':'
35 |
36 | def parse_option():
37 |
38 | parser = argparse.ArgumentParser('argument for training')
39 |
40 | # basic
41 | parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
42 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
43 | parser.add_argument('--num_workers', type=int, default=8, help='num of workers to use')
44 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
45 | parser.add_argument('--gpu_id', type=str, default='0', help='id(s) for CUDA_VISIBLE_DEVICES')
46 |
47 | # optimization
48 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
49 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='epochs at which to decay the learning rate (comma-separated list)')
50 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
51 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
52 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
53 |
54 | # dataset and model
55 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset')
56 | parser.add_argument('--model_s', type=str, default='resnet8x4')
57 | parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot')
58 |
59 | # distillation
60 | parser.add_argument('--trial', type=str, default='1', help='trial id')
61 | parser.add_argument('--kd_T', type=float, default=4, help='temperature for KD distillation')
62 | parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'hint', 'attention', 'similarity', 'vid',
63 | 'crd', 'semckd','srrl', 'simkd'])
64 | parser.add_argument('-c', '--cls', type=float, default=1.0, help='weight for classification')
65 | parser.add_argument('-d', '--div', type=float, default=1.0, help='weight balance for KD')
66 | parser.add_argument('-b', '--beta', type=float, default=0.0, help='weight balance for other losses')
67 | parser.add_argument('-f', '--factor', type=int, default=2, help='channel reduction factor of the SimKD projector')
68 | parser.add_argument('-s', '--soft', type=float, default=1.0, help='attention scale of SemCKD')
69 |
70 | # hint layer
71 | parser.add_argument('--hint_layer', default=1, type=int, choices=[0, 1, 2, 3, 4])
72 |
73 | # NCE distillation
74 | parser.add_argument('--feat_dim', default=128, type=int, help='feature dimension')
75 | parser.add_argument('--mode', default='exact', type=str, choices=['exact', 'relax'])
76 | parser.add_argument('--nce_k', default=16384, type=int, help='number of negative samples for NCE')
77 | parser.add_argument('--nce_t', default=0.07, type=float, help='temperature parameter for softmax')
78 | parser.add_argument('--nce_m', default=0.5, type=float, help='momentum for non-parametric updates')
79 |
80 | # multiprocessing
81 | parser.add_argument('--dali', type=str, choices=['cpu', 'gpu'], default=None)
82 | parser.add_argument('--multiprocessing-distributed', action='store_true',
83 | help='Use multi-processing distributed training to launch '
84 | 'N processes per node, which has N GPUs. This is the '
85 | 'fastest way to use PyTorch for either single node or '
86 | 'multi node data parallel training')
87 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:23451', type=str,
88 | help='url used to set up distributed training')
89 | parser.add_argument('--deterministic', action='store_true', help='Make results reproducible')
90 | parser.add_argument('--skip-validation', action='store_true', help='Skip validation of teacher')
91 |
92 | opt = parser.parse_args()
93 |
94 | # set different learning rates for these MobileNet/ShuffleNet models
95 | if opt.model_s in ['MobileNetV2', 'MobileNetV2_1_0', 'ShuffleV1', 'ShuffleV2', 'ShuffleV2_1_5']:
96 | opt.learning_rate = 0.01
97 |
98 | # set the path of model and tensorboard
99 | opt.model_path = './save/students/models'
100 | opt.tb_path = './save/students/tensorboard'
101 |
102 | iterations = opt.lr_decay_epochs.split(',')
103 | opt.lr_decay_epochs = []
104 | for it in iterations:
105 | opt.lr_decay_epochs.append(int(it))
106 |
107 | opt.model_t = get_teacher_name(opt.path_t)
108 |
109 | model_name_template = split_symbol.join(['S', '{}_T', '{}_{}_{}_r', '{}_a', '{}_b', '{}_{}'])
110 | opt.model_name = model_name_template.format(opt.model_s, opt.model_t, opt.dataset, opt.distill,
111 | opt.cls, opt.div, opt.beta, opt.trial)
112 |
113 | if opt.dali is not None:
114 | opt.model_name += '_dali:' + opt.dali
115 |
116 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
117 | if not os.path.isdir(opt.tb_folder):
118 | os.makedirs(opt.tb_folder)
119 |
120 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
121 | if not os.path.isdir(opt.save_folder):
122 | os.makedirs(opt.save_folder)
123 |
124 | return opt
125 |
126 | def get_teacher_name(model_path):
127 | """parse teacher name"""
128 | directory = model_path.split('/')[-2]
129 | pattern = ''.join(['S', split_symbol, '(.+)', '_T', split_symbol])
130 | name_match = re.match(pattern, directory)
131 | if name_match:
132 | return name_match[1]
133 | segments = directory.split('_')
134 | if segments[0] == 'wrn':
135 | return segments[0] + '_' + segments[1] + '_' + segments[2]
136 | return segments[0]
137 |
138 |
139 | def load_teacher(model_path, n_cls, gpu=None, opt=None):
140 | print('==> loading teacher model')
141 | model_t = get_teacher_name(model_path)
142 | model = model_dict[model_t](num_classes=n_cls)
143 | map_location = None if gpu is None else {'cuda:0': 'cuda:%d' % (gpu if opt.multiprocessing_distributed else 0)}
144 | model.load_state_dict(torch.load(model_path, map_location=map_location)['model'])
145 | print('==> done')
146 | return model
147 |
148 |
149 | best_acc = 0
150 | total_time = time.time()
151 | def main():
152 |
153 | opt = parse_option()
154 |
155 | # ASSIGN CUDA_ID
156 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
157 |
158 | ngpus_per_node = torch.cuda.device_count()
159 | opt.ngpus_per_node = ngpus_per_node
160 | if opt.multiprocessing_distributed:
161 | # Since we have ngpus_per_node processes per node, the total world_size
162 | # needs to be adjusted accordingly
163 | world_size = 1
164 | opt.world_size = ngpus_per_node * world_size
165 | # Use torch.multiprocessing.spawn to launch distributed processes: the
166 | # main_worker process function
167 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
168 | else:
169 | main_worker(None if ngpus_per_node > 1 else opt.gpu_id, ngpus_per_node, opt)
170 |
171 | def main_worker(gpu, ngpus_per_node, opt):
172 | global best_acc, total_time
173 | opt.gpu = int(gpu) if gpu is not None else None
174 | opt.gpu_id = opt.gpu
175 |
176 | if opt.gpu is not None:
177 | print("Use GPU: {} for training".format(opt.gpu))
178 |
179 | if opt.multiprocessing_distributed:
180 | # Only one node now.
181 | opt.rank = gpu
182 | dist_backend = 'nccl'
183 | dist.init_process_group(backend=dist_backend, init_method=opt.dist_url,
184 | world_size=opt.world_size, rank=opt.rank)
185 | opt.batch_size = int(opt.batch_size / ngpus_per_node)
186 | opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)
187 |
188 | if opt.deterministic:
189 | torch.manual_seed(12345)
190 | cudnn.deterministic = True
191 | cudnn.benchmark = False
192 | numpy.random.seed(12345)
193 |
194 | # model
195 | n_cls = {
196 | 'cifar100': 100,
197 | 'imagenet': 1000,
198 | }.get(opt.dataset, None)
199 |
200 | model_t = load_teacher(opt.path_t, n_cls, opt.gpu, opt)
201 | try:
202 | model_s = model_dict[opt.model_s](num_classes=n_cls)
203 | except KeyError:
204 | raise NotImplementedError('student model {} is not supported'.format(opt.model_s))
205 |
206 | if opt.dataset == 'cifar100':
207 | data = torch.randn(2, 3, 32, 32)
208 | elif opt.dataset == 'imagenet':
209 | data = torch.randn(2, 3, 224, 224)
210 |
211 | model_t.eval()
212 | model_s.eval()
213 | feat_t, _ = model_t(data, is_feat=True)
214 | feat_s, _ = model_s(data, is_feat=True)
215 |
216 | module_list = nn.ModuleList([])
217 | module_list.append(model_s)
218 | trainable_list = nn.ModuleList([])
219 | trainable_list.append(model_s)
220 |
221 | criterion_cls = nn.CrossEntropyLoss()
222 | criterion_div = DistillKL(opt.kd_T)
223 | if opt.distill == 'kd':
224 | criterion_kd = DistillKL(opt.kd_T)
225 | elif opt.distill == 'hint':
226 | criterion_kd = HintLoss()
227 | regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
228 | module_list.append(regress_s)
229 | trainable_list.append(regress_s)
230 | elif opt.distill == 'attention':
231 | criterion_kd = Attention()
232 | elif opt.distill == 'similarity':
233 | criterion_kd = Similarity()
234 | elif opt.distill == 'vid':
235 | s_n = [f.shape[1] for f in feat_s[1:-1]]
236 | t_n = [f.shape[1] for f in feat_t[1:-1]]
237 | criterion_kd = nn.ModuleList(
238 | [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)]
239 | )
240 | # add this as some parameters in VIDLoss need to be updated
241 | trainable_list.append(criterion_kd)
242 | elif opt.distill == 'crd':
243 | opt.s_dim = feat_s[-1].shape[1]
244 | opt.t_dim = feat_t[-1].shape[1]
245 | if opt.dataset == 'cifar100':
246 | opt.n_data = 50000
247 | else:
248 | opt.n_data = 1281167
249 | criterion_kd = CRDLoss(opt)
250 | module_list.append(criterion_kd.embed_s)
251 | module_list.append(criterion_kd.embed_t)
252 | trainable_list.append(criterion_kd.embed_s)
253 | trainable_list.append(criterion_kd.embed_t)
254 | elif opt.distill == 'semckd':
255 | s_n = [f.shape[1] for f in feat_s[1:-1]]
256 | t_n = [f.shape[1] for f in feat_t[1:-1]]
257 | criterion_kd = SemCKDLoss()
258 | self_attention = SelfA(opt.batch_size, s_n, t_n, opt.soft)
259 | module_list.append(self_attention)
260 | trainable_list.append(self_attention)
261 | elif opt.distill == 'srrl':
262 | s_n = feat_s[-1].shape[1]
263 | t_n = feat_t[-1].shape[1]
264 |         model_fmsr = SRRL(s_n=s_n, t_n=t_n)
265 | criterion_kd = nn.MSELoss()
266 | module_list.append(model_fmsr)
267 | trainable_list.append(model_fmsr)
268 | elif opt.distill == 'simkd':
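    |         # SimKD matches the student's penultimate features (index -2) to the
    |         # teacher's and reuses the frozen teacher classifier for prediction,
    |         # so a plain MSE on the projected features suffices as the loss.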
269 | s_n = feat_s[-2].shape[1]
270 | t_n = feat_t[-2].shape[1]
271 |         model_simkd = SimKD(s_n=s_n, t_n=t_n, factor=opt.factor)
272 | criterion_kd = nn.MSELoss()
273 | module_list.append(model_simkd)
274 | trainable_list.append(model_simkd)
275 | else:
276 | raise NotImplementedError(opt.distill)
277 |
278 | criterion_list = nn.ModuleList([])
279 | criterion_list.append(criterion_cls) # classification loss
280 | criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation
281 | criterion_list.append(criterion_kd) # other knowledge distillation loss
282 |
283 | module_list.append(model_t)
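    |     # The teacher is added to module_list so it moves to the GPU with everything
    |     # else, but not to trainable_list, so its weights stay frozen during training.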
284 |
285 | optimizer = optim.SGD(trainable_list.parameters(),
286 | lr=opt.learning_rate,
287 | momentum=opt.momentum,
288 | weight_decay=opt.weight_decay)
289 |
290 | if torch.cuda.is_available():
291 | # For multiprocessing distributed, DistributedDataParallel constructor
292 | # should always set the single device scope, otherwise,
293 | # DistributedDataParallel will use all available devices.
294 | if opt.multiprocessing_distributed:
295 | if opt.gpu is not None:
296 | torch.cuda.set_device(opt.gpu)
297 | module_list.cuda(opt.gpu)
298 | distributed_modules = []
299 | for module in module_list:
300 | DDP = torch.nn.parallel.DistributedDataParallel
301 | distributed_modules.append(DDP(module, device_ids=[opt.gpu]))
302 | module_list = distributed_modules
303 | criterion_list.cuda(opt.gpu)
304 |             else:
305 |                 print('multiprocessing_distributed must be used with a specific GPU id')
306 | else:
307 | criterion_list.cuda()
308 | module_list.cuda()
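    |         # cudnn benchmarking autotunes convolution kernels for speed; it stays
    |         # disabled when --deterministic was requested above.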
309 | if not opt.deterministic:
310 | cudnn.benchmark = True
311 |
312 | # dataloader
313 | if opt.dataset == 'cifar100':
314 | if opt.distill in ['crd']:
315 | train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(batch_size=opt.batch_size,
316 | num_workers=opt.num_workers,
317 | k=opt.nce_k,
318 | mode=opt.mode)
319 | else:
320 | train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size,
321 | num_workers=opt.num_workers)
322 | elif opt.dataset == 'imagenet':
323 | if opt.dali is None:
324 | if opt.distill in ['crd']:
325 | train_loader, val_loader, n_data, _, train_sampler = get_dataloader_sample(dataset=opt.dataset, batch_size=opt.batch_size,
326 | num_workers=opt.num_workers,
327 | is_sample=True,
328 | k=opt.nce_k,
329 | multiprocessing_distributed=opt.multiprocessing_distributed)
330 | else:
331 | train_loader, val_loader, train_sampler = get_imagenet_dataloader(dataset=opt.dataset, batch_size=opt.batch_size,
332 | num_workers=opt.num_workers,
333 | multiprocessing_distributed=opt.multiprocessing_distributed)
334 | else:
335 | train_loader, val_loader = get_dali_data_loader(opt)
336 | else:
337 | raise NotImplementedError(opt.dataset)
338 |
339 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
340 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
341 |
342 | if not opt.skip_validation:
343 | # validate teacher accuracy
344 | teacher_acc, _, _ = validate_vanilla(val_loader, model_t, criterion_cls, opt)
345 |
346 | if opt.dali is not None:
347 | val_loader.reset()
348 |
349 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
350 | print('teacher accuracy: ', teacher_acc)
351 | else:
352 | print('Skipping teacher validation.')
353 |
354 | # routine
355 | for epoch in range(1, opt.epochs + 1):
356 | torch.cuda.empty_cache()
357 | if opt.multiprocessing_distributed:
358 | if opt.dali is None:
359 | train_sampler.set_epoch(epoch)
360 |
361 | adjust_learning_rate(epoch, opt, optimizer)
362 | print("==> training...")
363 |
364 | time1 = time.time()
365 | train_acc, train_acc_top5, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
366 | time2 = time.time()
367 |
368 | if opt.multiprocessing_distributed:
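    |             # Aggregate the training metrics across all processes with the
    |             # reduce_tensor helper (helper/util.py) so rank 0 logs run-wide values.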
369 | metrics = torch.tensor([train_acc, train_acc_top5, train_loss]).cuda(opt.gpu, non_blocking=True)
370 | reduced = reduce_tensor(metrics, opt.world_size if 'world_size' in opt else 1)
371 | train_acc, train_acc_top5, train_loss = reduced.tolist()
372 |
373 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
374 | print(' * Epoch {}, GPU {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}'.format(epoch, opt.gpu, train_acc, train_acc_top5, time2 - time1))
375 |
376 | logger.log_value('train_acc', train_acc, epoch)
377 | logger.log_value('train_loss', train_loss, epoch)
378 |
379 |         print('GPU %s validating' % opt.gpu)
380 | test_acc, test_acc_top5, test_loss = validate_distill(val_loader, module_list, criterion_cls, opt)
381 |
382 | if opt.dali is not None:
383 | train_loader.reset()
384 | val_loader.reset()
385 |
386 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
387 | print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))
388 |
389 | logger.log_value('test_acc', test_acc, epoch)
390 | logger.log_value('test_loss', test_loss, epoch)
391 | logger.log_value('test_acc_top5', test_acc_top5, epoch)
392 |
393 | # save the best model
394 | if test_acc > best_acc:
395 | best_acc = test_acc
396 | state = {
397 | 'epoch': epoch,
398 | 'model': model_s.state_dict(),
399 | 'best_acc': best_acc,
400 | }
401 | if opt.distill == 'simkd':
402 | state['proj'] = trainable_list[-1].state_dict()
403 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
404 |
405 |             test_metrics = {'test_loss': test_loss,
406 |                             'test_acc': test_acc,
407 |                             'test_acc_top5': test_acc_top5,
408 |                             'epoch': epoch}
409 | 
410 |             save_dict_to_json(test_metrics, os.path.join(opt.save_folder, "test_best_metrics.json"))
411 | print('saving the best model!')
412 | torch.save(state, save_file)
413 |
414 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
415 |         # This best accuracy is for printing purposes only.
416 | print('best accuracy:', best_acc)
417 |
418 | # save parameters
419 | save_state = {k: v for k, v in opt._get_kwargs()}
420 |         # number of parameters, in millions
421 | num_params = (sum(p.numel() for p in model_s.parameters())/1000000.0)
422 | save_state['Total params'] = num_params
423 | save_state['Total time'] = (time.time() - total_time)/3600.0
424 | params_json_path = os.path.join(opt.save_folder, "parameters.json")
425 | save_dict_to_json(save_state, params_json_path)
426 |
427 | if __name__ == '__main__':
428 | main()
429 |
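    | # A hypothetical invocation (flag names follow the opt.* attributes used above;
    | # parse_option earlier in this file is the authoritative reference):
    | #   python train_student.py --path_t <teacher_checkpoint>.pth \
    | #       --model_s resnet8x4 --distill simkd --gpu_id 0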
--------------------------------------------------------------------------------
/train_teacher.py:
--------------------------------------------------------------------------------
1 | """
2 | Training a single model (student or teacher)
3 | """
4 |
5 | import os
6 | import argparse
7 | import time
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.optim as optim
12 | import torch.distributed as dist
13 | import torch.multiprocessing as mp
14 | import torch.backends.cudnn as cudnn
15 | import tensorboard_logger as tb_logger
16 |
17 | from models import model_dict
18 | from dataset.cifar100 import get_cifar100_dataloaders
19 | from dataset.imagenet import get_imagenet_dataloader
20 | from dataset.imagenet_dali import get_dali_data_loader
21 | from helper.util import save_dict_to_json, reduce_tensor, adjust_learning_rate
22 | from helper.loops import train_vanilla as train, validate_vanilla
23 |
24 | def parse_option():
25 |
26 | parser = argparse.ArgumentParser('argument for training')
27 |
28 |     # basic
29 | parser.add_argument('--print_freq', type=int, default=200, help='print frequency')
30 | parser.add_argument('--batch_size', type=int, default=64, help='batch_size')
31 | parser.add_argument('--num_workers', type=int, default=8, help='num_workers')
32 | parser.add_argument('--epochs', type=int, default=240, help='number of training epochs')
33 | parser.add_argument('--gpu_id', type=str, default='0', help='id(s) for CUDA_VISIBLE_DEVICES')
34 |
35 | # optimization
36 | parser.add_argument('--learning_rate', type=float, default=0.05, help='learning rate')
37 | parser.add_argument('--lr_decay_epochs', type=str, default='150,180,210', help='where to decay lr, can be a list')
38 | parser.add_argument('--lr_decay_rate', type=float, default=0.1, help='decay rate for learning rate')
39 | parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
40 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
41 |
42 | # dataset
43 | parser.add_argument('--model', type=str, default='resnet32x4')
44 | parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100', 'imagenet'], help='dataset')
45 | parser.add_argument('-t', '--trial', type=str, default='0', help='the experiment id')
46 | parser.add_argument('--dali', type=str, choices=['cpu', 'gpu'], default=None)
47 |
48 | # multiprocessing
49 | parser.add_argument('--multiprocessing-distributed', action='store_true',
50 | help='Use multi-processing distributed training to launch '
51 | 'N processes per node, which has N GPUs. This is the '
52 | 'fastest way to use PyTorch for either single node or '
53 | 'multi node data parallel training')
54 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:23451', type=str,
55 | help='url used to set up distributed training')
56 |
57 | opt = parser.parse_args()
58 |
59 | # set different learning rates for these MobileNet/ShuffleNet models
60 | if opt.model in ['MobileNetV2', 'MobileNetV2_1_0', 'ShuffleV1', 'ShuffleV2', 'ShuffleV2_1_5']:
61 | opt.learning_rate = 0.01
62 |
63 | # set the path of model and tensorboard
64 | opt.model_path = './save/teachers/models'
65 | opt.tb_path = './save/teachers/tensorboard'
66 |
67 | iterations = opt.lr_decay_epochs.split(',')
68 | opt.lr_decay_epochs = list([])
69 | for it in iterations:
70 | opt.lr_decay_epochs.append(int(it))
71 |
72 | # set the model name
73 | opt.model_name = '{}_vanilla_{}_trial_{}'.format(opt.model, opt.dataset, opt.trial)
74 | if opt.dali is not None:
75 | opt.model_name += '_dali:' + opt.dali
76 |
77 | opt.tb_folder = os.path.join(opt.tb_path, opt.model_name)
78 | if not os.path.isdir(opt.tb_folder):
79 | os.makedirs(opt.tb_folder)
80 |
81 | opt.save_folder = os.path.join(opt.model_path, opt.model_name)
82 | if not os.path.isdir(opt.save_folder):
83 | os.makedirs(opt.save_folder)
84 |
85 | return opt
86 |
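    | # Module-level state shared with main_worker: the best validation accuracy seen
    | # so far, and the script start time used to report total wall-clock hours.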
87 | best_acc = 0
88 | total_time = time.time()
89 | def main():
90 | opt = parse_option()
91 |
92 | # ASSIGN CUDA_ID
93 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id
94 |
95 | ngpus_per_node = torch.cuda.device_count()
96 | opt.ngpus_per_node = ngpus_per_node
97 | if opt.multiprocessing_distributed:
98 | # Since we have ngpus_per_node processes per node, the total world_size
99 | # needs to be adjusted accordingly
100 | world_size = 1
101 | opt.world_size = ngpus_per_node * world_size
102 | # Use torch.multiprocessing.spawn to launch distributed processes: the
103 | # main_worker process function
104 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt))
105 | else:
106 | main_worker(None if ngpus_per_node > 1 else opt.gpu_id, ngpus_per_node, opt)
107 |
108 | def main_worker(gpu, ngpus_per_node, opt):
109 | global best_acc, total_time
110 |     opt.gpu = int(gpu) if gpu is not None else None
111 |     opt.gpu_id = opt.gpu
112 |
113 | if opt.gpu is not None:
114 | print("Use GPU: {} for training".format(opt.gpu))
115 |
116 | if opt.multiprocessing_distributed:
117 |         # Single node only for now, so the spawned process index doubles as the global rank.
118 | opt.rank = int(gpu)
119 | dist_backend = 'nccl'
120 | dist.init_process_group(backend=dist_backend, init_method=opt.dist_url,
121 | world_size=opt.world_size, rank=opt.rank)
122 |
123 | # model
124 | n_cls = {
125 | 'cifar100': 100,
126 | 'imagenet': 1000,
127 | }.get(opt.dataset, None)
128 |
129 | try:
130 | model = model_dict[opt.model](num_classes=n_cls)
131 | except KeyError:
132 |         raise NotImplementedError('model {} is not supported'.format(opt.model))
133 |
134 | # optimizer
135 | optimizer = optim.SGD(model.parameters(),
136 | lr=opt.learning_rate,
137 | momentum=opt.momentum,
138 | weight_decay=opt.weight_decay)
139 | criterion = nn.CrossEntropyLoss()
140 |
141 | if torch.cuda.is_available():
142 | # For multiprocessing distributed, DistributedDataParallel constructor
143 | # should always set the single device scope, otherwise,
144 | # DistributedDataParallel will use all available devices.
145 | if opt.multiprocessing_distributed:
146 | if opt.gpu is not None:
147 | torch.cuda.set_device(opt.gpu)
148 | model = model.cuda(opt.gpu)
149 | criterion = criterion.cuda(opt.gpu)
150 | # When using a single GPU per process and per
151 | # DistributedDataParallel, we need to divide the batch size
152 | # ourselves based on the total number of GPUs we have
153 | opt.batch_size = int(opt.batch_size / ngpus_per_node)
154 | opt.num_workers = int((opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)
155 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[opt.gpu])
156 |             else:
157 |                 print('multiprocessing_distributed must be used with a specific GPU id')
158 | else:
159 | criterion = criterion.cuda()
160 | if torch.cuda.device_count() > 1:
161 | model = nn.DataParallel(model).cuda()
162 | else:
163 | model = model.cuda()
164 |
165 |
166 | cudnn.benchmark = True
167 |
168 | # dataloader
169 | if opt.dataset == 'cifar100':
170 | train_loader, val_loader = get_cifar100_dataloaders(batch_size=opt.batch_size, num_workers=opt.num_workers)
171 | elif opt.dataset == 'imagenet':
172 | if opt.dali is None:
173 | train_loader, val_loader, train_sampler = get_imagenet_dataloader(
174 | dataset = opt.dataset,
175 | batch_size=opt.batch_size, num_workers=opt.num_workers,
176 | multiprocessing_distributed=opt.multiprocessing_distributed)
177 | else:
178 | train_loader, val_loader = get_dali_data_loader(opt)
179 | else:
180 | raise NotImplementedError(opt.dataset)
181 |
182 | # tensorboard
183 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
184 | logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)
185 |
186 | # routine
187 | for epoch in range(1, opt.epochs + 1):
188 | if opt.multiprocessing_distributed:
189 | if opt.dali is None:
190 | train_sampler.set_epoch(epoch)
191 |             # No test_sampler needed: set_epoch only reshuffles the training split; validation is read sequentially.
192 |
193 | adjust_learning_rate(epoch, opt, optimizer)
194 | print("==> training...")
195 |
196 | time1 = time.time()
197 | train_acc, train_acc_top5, train_loss = train(epoch, train_loader, model, criterion, optimizer, opt)
198 | time2 = time.time()
199 |
200 | if opt.multiprocessing_distributed:
201 | metrics = torch.tensor([train_acc, train_acc_top5, train_loss]).cuda(opt.gpu, non_blocking=True)
202 | reduced = reduce_tensor(metrics, opt.world_size if 'world_size' in opt else 1)
203 | train_acc, train_acc_top5, train_loss = reduced.tolist()
204 |
205 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
206 | print(' * Epoch {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}'.format(epoch, train_acc, train_acc_top5, time2 - time1))
207 |
208 | logger.log_value('train_acc', train_acc, epoch)
209 | logger.log_value('train_loss', train_loss, epoch)
210 |
211 | test_acc, test_acc_top5, test_loss = validate_vanilla(val_loader, model, criterion, opt)
212 |
213 | if opt.dali is not None:
214 | train_loader.reset()
215 | val_loader.reset()
216 |
217 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
218 | print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(test_acc, test_acc_top5))
219 |
220 | logger.log_value('test_acc', test_acc, epoch)
221 | logger.log_value('test_acc_top5', test_acc_top5, epoch)
222 | logger.log_value('test_loss', test_loss, epoch)
223 |
224 | # save the best model
225 | if test_acc > best_acc:
226 | best_acc = test_acc
227 | state = {
228 | 'epoch': epoch,
229 | 'best_acc': best_acc,
230 |             'model': model.module.state_dict() if hasattr(model, 'module') else model.state_dict(),
231 | }
232 | save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model))
233 |
234 |             test_metrics = {'test_loss': float('%.2f' % test_loss),
235 |                             'test_acc': float('%.2f' % test_acc),
236 |                             'test_acc_top5': float('%.2f' % test_acc_top5),
237 |                             'epoch': epoch}
238 | 
239 |             save_dict_to_json(test_metrics, os.path.join(opt.save_folder, "test_best_metrics.json"))
240 |
241 | print('saving the best model!')
242 | torch.save(state, save_file)
243 |
244 | if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
245 |         # This best accuracy is for printing purposes only.
246 | print('best accuracy:', best_acc)
247 |
248 | # save parameters
249 |         save_state = {k: v for k, v in opt._get_kwargs()}
250 | 
251 |         # number of parameters, in millions
252 |         num_params = (sum(p.numel() for p in model.parameters())/1000000.0)
253 |         save_state['Total params'] = num_params
254 |         save_state['Total time'] = float('%.2f' % ((time.time() - total_time) / 3600.0))
255 |         params_json_path = os.path.join(opt.save_folder, "parameters.json")
256 |         save_dict_to_json(save_state, params_json_path)
257 |
258 | if __name__ == '__main__':
259 | main()
260 |
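    | # A typical invocation using only the flags defined in parse_option above, e.g.:
    | #   python train_teacher.py --model resnet32x4 --dataset cifar100 --trial 0 --gpu_id 0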
--------------------------------------------------------------------------------