├── README.md
├── cifar.py
├── command.sh
├── frm.png
├── models
│   ├── ShuffleNetv1.py
│   ├── ShuffleNetv2.py
│   ├── __init__.py
│   ├── classifier.py
│   ├── mobilenetv2.py
│   ├── resnet.py
│   ├── resnetv2.py
│   ├── util.py
│   ├── vgg.py
│   └── wrn.py
├── student.py
├── teacher.py
├── utils.py
└── wrapper.py
/README.md:
--------------------------------------------------------------------------------
1 | # SSKD
2 | This repo is the implementation of the paper [Knowledge Distillation Meets Self-Supervision](https://arxiv.org/abs/2006.07114) (ECCV 2020).
3 |
4 |
5 |
6 | ## Prerequisite
7 | This repo is tested with Ubuntu 16.04.5, Python 3.7, PyTorch 1.5.0, CUDA 10.2.
8 | Make sure to install PyTorch, torchvision, tensorboardX, and NumPy before using this repo.
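
As a quick sanity check before training, the following minimal sketch reports which of the packages listed above cannot be found in the current environment. The helper name `missing_packages` is hypothetical and not part of this repo; it relies only on the standard library.

```python
import importlib.util

# Import names of the packages listed in the Prerequisite section.
REQUIRED = ("torch", "torchvision", "tensorboardX", "numpy")


def missing_packages(names=REQUIRED):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Please install:", ", ".join(missing))
    else:
        print("All prerequisites found.")
```

This only checks that the packages are importable; it does not verify the versions the repo was tested with.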
9 |
10 | ## Running
11 |
12 | ### Teacher Training
13 | An example of teacher training is:
14 | ```
15 | python teacher.py --arch wrn_40_2 --lr 0.05 --gpu-id 0
16 | ```
17 | where you can specify the architecture via the `--arch` flag.
18 |
19 | You can also download all the pre-trained teacher models [here](https://drive.google.com/drive/folders/1vJ0VdeFRd9a50ObbBD8SslBtmqmj8p8r?usp=sharing).
20 | To run `student.py` directly with a downloaded model, you have to re-organise the directory. For instance, after downloading *vgg13.pth*, create a directory for it, say *teacher_vgg13*, and create a directory *ckpt* inside it. Then move *vgg13.pth* into *teacher_vgg13/ckpt* and rename it to *best.pth*. For a simpler way to use a pre-trained model, you can edit the loading code in `student.py` (line 90).
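
The directory layout described above can be set up with a few lines of Python. This is a minimal sketch, not part of the repo: the helper name `install_teacher_checkpoint` is hypothetical, and placing the result under `./experiments/` is an assumption based on the `--t-path` used in the student training example. It copies the file instead of moving it, so the download stays intact.

```python
import shutil
from pathlib import Path


def install_teacher_checkpoint(downloaded, experiments_root, arch):
    """Copy e.g. vgg13.pth to <experiments_root>/teacher_<arch>/ckpt/best.pth."""
    ckpt_dir = Path(experiments_root) / f"teacher_{arch}" / "ckpt"
    ckpt_dir.mkdir(parents=True, exist_ok=True)  # creates teacher_<arch>/ckpt
    target = ckpt_dir / "best.pth"
    shutil.copy2(downloaded, target)  # copy under the name the README asks for
    return target


if __name__ == "__main__":
    # Assumed locations; adjust to wherever the file was downloaded.
    source = Path("vgg13.pth")
    if source.exists():
        print(install_teacher_checkpoint(source, "./experiments", "vgg13"))
```

With this layout, `--t-path ./experiments/teacher_vgg13/` would point at the directory that contains `ckpt/best.pth`.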
21 |
22 | ### Student Training
23 | An example of student training is:
24 | ```
25 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_16_2 --lr 0.05 --gpu-id 0
26 | ```
27 | The flags mean:
28 | > `--t-path`: teacher's checkpoint path. The checkpoint whose file name contains the keyword 'best' is located automatically.
29 |
30 | > `--s-arch`: student's architecture.
31 |
32 | All the commands can be found in `command.sh`.
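
The checkpoint lookup described for `--t-path` can be pictured with the sketch below. It is an assumption about the behaviour, not the code in `student.py`: the helper name `find_best_checkpoint` is hypothetical, and the `ckpt` sub-directory is taken from the layout described in the Teacher Training section.

```python
from pathlib import Path


def find_best_checkpoint(t_path):
    """Return the first file under <t_path>/ckpt whose name contains 'best'."""
    ckpt_dir = Path(t_path) / "ckpt"  # assumed layout: <t_path>/ckpt/best.pth
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(f"no ckpt directory under {t_path}")
    matches = sorted(
        p for p in ckpt_dir.iterdir() if p.is_file() and "best" in p.name
    )
    if not matches:
        raise FileNotFoundError(f"no 'best' checkpoint in {ckpt_dir}")
    return matches[0]
```

Sorting the matches makes the choice deterministic if several files happen to contain the keyword.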
33 |
34 | ## Results (Top-1 Acc) on CIFAR100
35 |
36 | ### Similar-Architecture
37 |
38 | | Teacher <br> Student | wrn40-2 <br> wrn16-2 | wrn40-2 <br> wrn40-1 | resnet56 <br> resnet20 | resnet32x4 <br> resnet8x4 | vgg13 <br> vgg8 |
39 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:|:-----------:|
40 | | Teacher <br> Student | 76.46 <br> 73.64 | 76.46 <br> 72.24 | 73.44 <br> 69.63 | 79.63 <br> 72.51 | 75.38 <br> 70.68 |
41 | | KD | 74.92 | 73.54 | 70.66 | 73.33 | 72.98 |
42 | | FitNet | 75.75 | 74.12 | 71.60 | 74.31 | 73.54 |
43 | | AT | 75.28 | 74.45 | **71.78** | 74.26 | 73.62 |
44 | | SP | 75.34 | 73.15 | 71.48 | 74.74 | 73.44 |
45 | | VID | 74.79 | 74.20 | 71.71 | 74.82 | 73.96 |
46 | | RKD | 75.40 | 73.87 | 71.48 | 74.47 | 73.72 |
47 | | PKT | 76.01 | 74.40 | 71.44 | 74.17 | 73.37 |
48 | | AB | 68.89 | 75.06 | 71.49 | 74.45 | 74.27 |
49 | | FT | 75.15 | 74.37 | 71.52 | 75.02 | 73.42 |
50 | | CRD | **76.04** | 75.52 | 71.68 | 75.90 | 74.06 |
51 | | **SSKD** | **76.04** | **76.13** | 71.49 | **76.20** | **75.33** |
52 |
53 | ### Cross-Architecture
54 |
55 | | Teacher <br> Student | vgg13 <br> MobileNetV2 | ResNet50 <br> MobileNetV2 | ResNet50 <br> vgg8 | resnet32x4 <br> ShuffleV1 | resnet32x4 <br> ShuffleV2 | wrn40-2 <br> ShuffleV1 |
56 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:|:-----------:|:-------------:|
57 | | Teacher <br> Student | 75.38 <br> 65.79 | 79.10 <br> 65.79 | 79.10 <br> 70.68 | 79.63 <br> 70.77 | 79.63 <br> 73.12 | 76.46 <br> 70.77 |
58 | | KD | 67.37 | 67.35 | 73.81 | 74.07 | 74.45 | 74.83 |
59 | | FitNet | 68.58 | 68.54 | 73.84 | 74.82 | 75.11 | 75.55 |
60 | | AT | 69.34 | 69.28 | 73.45 | 74.76 | 75.30 | 75.61 |
61 | | SP | 66.89 | 68.99 | 73.86 | 73.80 | 75.15 | 75.56 |
62 | | VID | 66.91 | 68.88 | 73.75 | 74.28 | 75.78 | 75.36 |
63 | | RKD | 68.50 | 68.46 | 73.73 | 74.20 | 75.74 | 75.45 |
64 | | PKT | 67.89 | 68.44 | 73.53 | 74.06 | 75.18 | 75.51 |
65 | | AB | 68.86 | 69.32 | 74.20 | 76.24 | 75.66 | 76.58 |
66 | | FT | 69.19 | 69.01 | 73.58 | 74.31 | 74.95 | 75.18 |
67 | | CRD | 68.49 | 70.32 | 74.42 | 75.46 | 75.72 | 75.96 |
68 | | **SSKD** | **71.53** | **72.57** | **75.76** | **78.44** | **78.61** | **77.40** |
69 |
70 | ## Citation
71 | If you find this repo useful for your research, please consider citing the paper:
72 | ```
73 | @inproceedings{xu2020knowledge,
74 | title={Knowledge Distillation Meets Self-Supervision},
75 | author={Xu, Guodong and Liu, Ziwei and Li, Xiaoxiao and Loy, Chen Change},
76 | booktitle={European Conference on Computer Vision (ECCV)},
77 | year={2020},
78 | }
79 | ```
80 | ## Acknowledgement
81 | The implementation of `models` is borrowed from [CRD](https://github.com/HobbitLong/RepDistiller).
82 |
--------------------------------------------------------------------------------
/cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from PIL import Image
3 | import os
4 | import os.path
5 | import numpy as np
6 | import sys
7 |
8 | import pickle
9 | import torch
10 | import torch.utils.data as data
11 | from torchvision.datasets.utils import check_integrity, download_url
12 | from torchvision.datasets.vision import StandardTransform
13 |
14 | class VisionDataset(data.Dataset):
15 |     _repr_indent = 4
16 |
17 |     def __init__(self, root, transforms=None, transform=None, target_transform=None):
18 |         if isinstance(root, (str, bytes)):  # replaces torch._six.string_classes
19 |             root = os.path.expanduser(root)
20 |         self.root = root
21 |
22 |         has_transforms = transforms is not None
23 |         has_separate_transform = transform is not None or target_transform is not None
24 |         if has_transforms and has_separate_transform:
25 |             raise ValueError("Only transforms or transform/target_transform can "
26 |                              "be passed as argument")
27 |
28 |         # for backwards-compatibility
29 |         self.transform = transform
30 |         self.target_transform = target_transform
31 |
32 |         if has_separate_transform:
33 |             transforms = StandardTransform(transform, target_transform)
34 |         self.transforms = transforms
35 |
36 |     def __getitem__(self, index):
37 |         raise NotImplementedError
38 |
39 |     def __len__(self):
40 |         raise NotImplementedError
41 |
42 |     def __repr__(self):
43 |         head = "Dataset " + self.__class__.__name__
44 |         body = ["Number of datapoints: {}".format(self.__len__())]
45 |         if self.root is not None:
46 |             body.append("Root location: {}".format(self.root))
47 |         body += self.extra_repr().splitlines()
48 |         if self.transforms is not None:
49 |             body += [repr(self.transforms)]
50 |         lines = [head] + [" " * self._repr_indent + line for line in body]
51 |         return '\n'.join(lines)
52 |
53 |     def _format_transform_repr(self, transform, head):
54 |         lines = transform.__repr__().splitlines()
55 |         return (["{}{}".format(head, lines[0])] +
56 |                 ["{}{}".format(" " * len(head), line) for line in lines[1:]])
57 |
58 |     def extra_repr(self):
59 |         return ""
60 |
61 | class CIFAR10(VisionDataset):
62 |     base_folder = 'cifar-10-batches-py'
63 |     url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
64 |     filename = "cifar-10-python.tar.gz"
65 |     tgz_md5 = 'c58f30108f718f92721af3b95e74349a'
66 |     train_list = [
67 |         ['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
68 |         ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
69 |         ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
70 |         ['data_batch_4', '634d18415352ddfa80567beed471001a'],
71 |         ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
72 |     ]
73 |
74 |     test_list = [
75 |         ['test_batch', '40351d587109b95175f43aff81a1287e'],
76 |     ]
77 |     meta = {
78 |         'filename': 'batches.meta',
79 |         'key': 'label_names',
80 |         'md5': '5ff9c542aee3614f3951f8cda6e48888',
81 |     }
82 |
83 |     def __init__(self, root, train=True,
84 |                  transform=None, download=False):
85 |
86 |         super(CIFAR10, self).__init__(root)
87 |         self.transform = transform
88 |
89 |         self.train = train  # training set or test set
90 |
91 |         if download:
92 |             # Downloading is disabled; the data must already be under `root`.
93 |             raise ValueError('cannot download.')
94 |             #self.download()
95 |
96 |         #if not self._check_integrity():
97 |         #    raise RuntimeError('Dataset not found or corrupted.' +
98 |         #                       ' You can use download=True to download it')
99 |
100 |         if self.train:
101 |             downloaded_list = self.train_list
102 |         else:
103 |             downloaded_list = self.test_list
104 |
105 |         self.data = []
106 |         self.targets = []
107 |
108 |         # now load the pickled numpy arrays
109 |         for file_name, checksum in downloaded_list:
110 |             file_path = os.path.join(self.root, self.base_folder, file_name)
111 |             with open(file_path, 'rb') as f:
112 |                 if sys.version_info[0] == 2:
113 |                     entry = pickle.load(f)
114 |                 else:
115 |                     entry = pickle.load(f, encoding='latin1')
116 |                 self.data.append(entry['data'])
117 |                 if 'labels' in entry:
118 |                     self.targets.extend(entry['labels'])
119 |                 else:
120 |                     self.targets.extend(entry['fine_labels'])
121 |
122 |         self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
123 |         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
124 |
125 |         self._load_meta()
126 |
127 |     def _load_meta(self):
128 |         path = os.path.join(self.root, self.base_folder, self.meta['filename'])
129 |         #if not check_integrity(path, self.meta['md5']):
130 |         #    raise RuntimeError('Dataset metadata file not found or corrupted.' +
131 |         #                       ' You can use download=True to download it')
132 |         with open(path, 'rb') as infile:
133 |             if sys.version_info[0] == 2:
134 |                 data = pickle.load(infile)
135 |             else:
136 |                 data = pickle.load(infile, encoding='latin1')
137 |             self.classes = data[self.meta['key']]
138 |         self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
139 |
140 |     def __getitem__(self, index):
141 |
142 |         img, target = self.data[index], self.targets[index]
143 |         if self.train:
144 |             if np.random.rand() < 0.5:
145 |                 img = img[:,::-1,:]
146 |
147 |         img0 = np.rot90(img, 0).copy()
148 |         img0 = Image.fromarray(img0)
149 |         img0 = self.transform(img0)
150 |
151 |         img1 = np.rot90(img, 1).copy()
152 |         img1 = Image.fromarray(img1)
153 |         img1 = self.transform(img1)
154 |
155 |         img2 = np.rot90(img, 2).copy()
156 |         img2 = Image.fromarray(img2)
157 |         img2 = self.transform(img2)
158 |
159 |         img3 = np.rot90(img, 3).copy()
160 |         img3 = Image.fromarray(img3)
161 |         img3 = self.transform(img3)
162 |
163 |         img = torch.stack([img0,img1,img2,img3])
164 |
165 |         return img, target
166 |
167 |
168 |     def __len__(self):
169 |         return len(self.data)
170 |
171 |     def _check_integrity(self):
172 |         root = self.root
173 |         for fentry in (self.train_list + self.test_list):
174 |             filename, md5 = fentry[0], fentry[1]
175 |             fpath = os.path.join(root, self.base_folder, filename)
176 |             if not check_integrity(fpath, md5):
177 |                 return False
178 |         return True
179 |
180 |     def download(self):
181 |         import tarfile
182 |
183 |         if self._check_integrity():
184 |             print('Files already downloaded and verified')
185 |             return
186 |
187 |         download_url(self.url, self.root, self.filename, self.tgz_md5)
188 |
189 |         # extract file
190 |         with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
191 |             tar.extractall(path=self.root)
192 |
193 |     def extra_repr(self):
194 |         return "Split: {}".format("Train" if self.train is True else "Test")
195 |
196 |
197 | class CIFAR100(CIFAR10):
198 |     """CIFAR100 Dataset.
199 |
200 |     This is a subclass of the `CIFAR10` Dataset.
201 |     """
202 |     base_folder = 'cifar-100-python'
203 |     url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
204 |     filename = "cifar-100-python.tar.gz"
205 |     tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
206 |     train_list = [
207 |         ['train', '16019d7e3df5f24257cddd939b257f8d'],
208 |     ]
209 |
210 |     test_list = [
211 |         ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
212 |     ]
213 |     meta = {
214 |         'filename': 'meta',
215 |         'key': 'fine_label_names',
216 |         'md5': '7973b15100ade9c7d40fb424638fde48',
217 |     }
218 |
--------------------------------------------------------------------------------
/command.sh:
--------------------------------------------------------------------------------
1 | # teacher training
2 | python teacher.py --arch wrn_40_2 --lr 0.05 --gpu-id 0
3 | python teacher.py --arch wrn_40_1 --lr 0.05 --gpu-id 0
4 | python teacher.py --arch wrn_16_2 --lr 0.05 --gpu-id 0
5 | python teacher.py --arch vgg13 --lr 0.05 --gpu-id 0
6 | python teacher.py --arch vgg8 --lr 0.05 --gpu-id 0
7 | python teacher.py --arch resnet56 --lr 0.05 --gpu-id 0
8 | python teacher.py --arch resnet20 --lr 0.05 --gpu-id 0
9 | python teacher.py --arch resnet32x4 --lr 0.05 --gpu-id 0
10 | python teacher.py --arch resnet8x4 --lr 0.05 --gpu-id 0
11 | python teacher.py --arch ResNet50 --lr 0.05 --gpu-id 0
12 | python teacher.py --arch ShuffleV1 --lr 0.01 --gpu-id 0
13 | python teacher.py --arch ShuffleV2 --lr 0.01 --gpu-id 0
14 | python teacher.py --arch MobileNetV2 --lr 0.01 --gpu-id 0
15 |
16 |
17 | # student training
18 |
19 | # similar-architecture
20 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_16_2 --lr 0.05 --gpu-id 0
21 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_40_1 --lr 0.05 --gpu-id 0
22 | python student.py --t-path ./experiments/teacher_resnet56_seed0/ --s-arch resnet20 --lr 0.05 --gpu-id 0
23 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch resnet8x4 --lr 0.05 --gpu-id 0
24 | python student.py --t-path ./experiments/teacher_vgg13_seed0/ --s-arch vgg8 --lr 0.05 --gpu-id 0
25 | # different-architecture
26 | python student.py --t-path ./experiments/teacher_vgg13_seed0/ --s-arch MobileNetV2 --lr 0.01 --gpu-id 0
27 | python student.py --t-path ./experiments/teacher_ResNet50_seed0/ --s-arch MobileNetV2 --lr 0.01 --gpu-id 0
28 | python student.py --t-path ./experiments/teacher_ResNet50_seed0/ --s-arch vgg8 --lr 0.05 --gpu-id 0
29 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch ShuffleV1 --lr 0.01 --gpu-id 0
30 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch ShuffleV2 --lr 0.01 --gpu-id 0
31 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch ShuffleV1 --lr 0.01 --gpu-id 0
32 |
33 |
--------------------------------------------------------------------------------
/frm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuguodong03/SSKD/661b972c124a83c32dcd0e203390ba637075fe92/frm.png
--------------------------------------------------------------------------------
/models/ShuffleNetv1.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNet in PyTorch.
2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 |     def __init__(self, groups):
11 |         super(ShuffleBlock, self).__init__()
12 |         self.groups = groups
13 |
14 |     def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 |         N,C,H,W = x.size()
17 |         g = self.groups
18 |         return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W)
19 |
20 |
21 | class Bottleneck(nn.Module):
22 |     def __init__(self, in_planes, out_planes, stride, groups, is_last=False):
23 |         super(Bottleneck, self).__init__()
24 |         self.is_last = is_last
25 |         self.stride = stride
26 |
27 |         mid_planes = int(out_planes/4)
28 |         g = 1 if in_planes == 24 else groups
29 |         self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False)
30 |         self.bn1 = nn.BatchNorm2d(mid_planes)
31 |         self.shuffle1 = ShuffleBlock(groups=g)
32 |         self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False)
33 |         self.bn2 = nn.BatchNorm2d(mid_planes)
34 |         self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
35 |         self.bn3 = nn.BatchNorm2d(out_planes)
36 |
37 |         self.shortcut = nn.Sequential()
38 |         if stride == 2:
39 |             self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
40 |
41 |     def forward(self, x):
42 |         out = F.relu(self.bn1(self.conv1(x)))
43 |         out = self.shuffle1(out)
44 |         out = F.relu(self.bn2(self.conv2(out)))
45 |         out = self.bn3(self.conv3(out))
46 |         res = self.shortcut(x)
47 |         preact = torch.cat([out, res], 1) if self.stride == 2 else out+res
48 |         out = F.relu(preact)
49 |         # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res)
50 |         if self.is_last:
51 |             return out, preact
52 |         else:
53 |             return out
54 |
55 |
56 | class ShuffleNet(nn.Module):
57 |     def __init__(self, cfg, num_classes=10):
58 |         super(ShuffleNet, self).__init__()
59 |         out_planes = cfg['out_planes']
60 |         num_blocks = cfg['num_blocks']
61 |         groups = cfg['groups']
62 |
63 |         self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
64 |         self.bn1 = nn.BatchNorm2d(24)
65 |         self.in_planes = 24
66 |         self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
67 |         self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
68 |         self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
69 |         self.linear = nn.Linear(out_planes[2], num_classes)
70 |
71 |     def _make_layer(self, out_planes, num_blocks, groups):
72 |         layers = []
73 |         for i in range(num_blocks):
74 |             stride = 2 if i == 0 else 1
75 |             cat_planes = self.in_planes if i == 0 else 0
76 |             layers.append(Bottleneck(self.in_planes, out_planes-cat_planes,
77 |                                      stride=stride,
78 |                                      groups=groups,
79 |                                      is_last=(i == num_blocks - 1)))
80 |             self.in_planes = out_planes
81 |         return nn.Sequential(*layers)
82 |
83 |     def get_feat_modules(self):
84 |         feat_m = nn.ModuleList([])
85 |         feat_m.append(self.conv1)
86 |         feat_m.append(self.bn1)
87 |         feat_m.append(self.layer1)
88 |         feat_m.append(self.layer2)
89 |         feat_m.append(self.layer3)
90 |         return feat_m
91 |
92 |     def get_bn_before_relu(self):
93 |         raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher')
94 |
95 |     def forward(self, x, is_feat=False, preact=False):
96 |         out = F.relu(self.bn1(self.conv1(x)))
97 |         f0 = out
98 |         out, f1_pre = self.layer1(out)
99 |         f1 = out
100 |         out, f2_pre = self.layer2(out)
101 |         f2 = out
102 |         out, f3_pre = self.layer3(out)
103 |         f3 = out
104 |         out = F.avg_pool2d(out, 4)
105 |         out = out.view(out.size(0), -1)
106 |         f4 = out
107 |         out = self.linear(out)
108 |
109 |         if is_feat:
110 |             if preact:
111 |                 return [f0, f1_pre, f2_pre, f3_pre, f4], out
112 |             else:
113 |                 return [f0, f1, f2, f3, f4], out
114 |         else:
115 |             return out
116 |
117 |
118 | def ShuffleV1(**kwargs):
119 |     cfg = {
120 |         'out_planes': [240, 480, 960],
121 |         'num_blocks': [4, 8, 4],
122 |         'groups': 3
123 |     }
124 |     return ShuffleNet(cfg, **kwargs)
125 |
126 |
127 | if __name__ == '__main__':
128 |
129 |     x = torch.randn(2, 3, 32, 32)
130 |     net = ShuffleV1(num_classes=100)
131 |     import time
132 |     a = time.time()
133 |     feats, logit = net(x, is_feat=True, preact=True)
134 |     b = time.time()
135 |     print(b - a)
136 |     for f in feats:
137 |         print(f.shape, f.min().item())
138 |     print(logit.shape)
139 |
--------------------------------------------------------------------------------
/models/ShuffleNetv2.py:
--------------------------------------------------------------------------------
1 | '''ShuffleNetV2 in PyTorch.
2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class ShuffleBlock(nn.Module):
10 |     def __init__(self, groups=2):
11 |         super(ShuffleBlock, self).__init__()
12 |         self.groups = groups
13 |
14 |     def forward(self, x):
15 |         '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]'''
16 |         N, C, H, W = x.size()
17 |         g = self.groups
18 |         return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W)
19 |
20 |
21 | class SplitBlock(nn.Module):
22 |     def __init__(self, ratio):
23 |         super(SplitBlock, self).__init__()
24 |         self.ratio = ratio
25 |
26 |     def forward(self, x):
27 |         c = int(x.size(1) * self.ratio)
28 |         return x[:, :c, :, :], x[:, c:, :, :]
29 |
30 |
31 | class BasicBlock(nn.Module):
32 |     def __init__(self, in_channels, split_ratio=0.5, is_last=False):
33 |         super(BasicBlock, self).__init__()
34 |         self.is_last = is_last
35 |         self.split = SplitBlock(split_ratio)
36 |         in_channels = int(in_channels * split_ratio)
37 |         self.conv1 = nn.Conv2d(in_channels, in_channels,
38 |                                kernel_size=1, bias=False)
39 |         self.bn1 = nn.BatchNorm2d(in_channels)
40 |         self.conv2 = nn.Conv2d(in_channels, in_channels,
41 |                                kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False)
42 |         self.bn2 = nn.BatchNorm2d(in_channels)
43 |         self.conv3 = nn.Conv2d(in_channels, in_channels,
44 |                                kernel_size=1, bias=False)
45 |         self.bn3 = nn.BatchNorm2d(in_channels)
46 |         self.shuffle = ShuffleBlock()
47 |
48 |     def forward(self, x):
49 |         x1, x2 = self.split(x)
50 |         out = F.relu(self.bn1(self.conv1(x2)))
51 |         out = self.bn2(self.conv2(out))
52 |         preact = self.bn3(self.conv3(out))
53 |         out = F.relu(preact)
54 |         # out = F.relu(self.bn3(self.conv3(out)))
55 |         preact = torch.cat([x1, preact], 1)
56 |         out = torch.cat([x1, out], 1)
57 |         out = self.shuffle(out)
58 |         if self.is_last:
59 |             return out, preact
60 |         else:
61 |             return out
62 |
63 |
64 | class DownBlock(nn.Module):
65 |     def __init__(self, in_channels, out_channels):
66 |         super(DownBlock, self).__init__()
67 |         mid_channels = out_channels // 2
68 |         # left
69 |         self.conv1 = nn.Conv2d(in_channels, in_channels,
70 |                                kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False)
71 |         self.bn1 = nn.BatchNorm2d(in_channels)
72 |         self.conv2 = nn.Conv2d(in_channels, mid_channels,
73 |                                kernel_size=1, bias=False)
74 |         self.bn2 = nn.BatchNorm2d(mid_channels)
75 |         # right
76 |         self.conv3 = nn.Conv2d(in_channels, mid_channels,
77 |                                kernel_size=1, bias=False)
78 |         self.bn3 = nn.BatchNorm2d(mid_channels)
79 |         self.conv4 = nn.Conv2d(mid_channels, mid_channels,
80 |                                kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False)
81 |         self.bn4 = nn.BatchNorm2d(mid_channels)
82 |         self.conv5 = nn.Conv2d(mid_channels, mid_channels,
83 |                                kernel_size=1, bias=False)
84 |         self.bn5 = nn.BatchNorm2d(mid_channels)
85 |
86 |         self.shuffle = ShuffleBlock()
87 |
88 |     def forward(self, x):
89 |         # left
90 |         out1 = self.bn1(self.conv1(x))
91 |         out1 = F.relu(self.bn2(self.conv2(out1)))
92 |         # right
93 |         out2 = F.relu(self.bn3(self.conv3(x)))
94 |         out2 = self.bn4(self.conv4(out2))
95 |         out2 = F.relu(self.bn5(self.conv5(out2)))
96 |         # concat
97 |         out = torch.cat([out1, out2], 1)
98 |         out = self.shuffle(out)
99 |         return out
100 |
101 |
102 | class ShuffleNetV2(nn.Module):
103 |     def __init__(self, net_size, num_classes=10):
104 |         super(ShuffleNetV2, self).__init__()
105 |         out_channels = configs[net_size]['out_channels']
106 |         num_blocks = configs[net_size]['num_blocks']
107 |
108 |         # self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
109 |         #                        stride=1, padding=1, bias=False)
110 |         self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
111 |         self.bn1 = nn.BatchNorm2d(24)
112 |         self.in_channels = 24
113 |         self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
114 |         self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
115 |         self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
116 |         self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
117 |                                kernel_size=1, stride=1, padding=0, bias=False)
118 |         self.bn2 = nn.BatchNorm2d(out_channels[3])
119 |         self.linear = nn.Linear(out_channels[3], num_classes)
120 |
121 |     def _make_layer(self, out_channels, num_blocks):
122 |         layers = [DownBlock(self.in_channels, out_channels)]
123 |         for i in range(num_blocks):
124 |             layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1)))
125 |             self.in_channels = out_channels
126 |         return nn.Sequential(*layers)
127 |
128 |     def get_feat_modules(self):
129 |         feat_m = nn.ModuleList([])
130 |         feat_m.append(self.conv1)
131 |         feat_m.append(self.bn1)
132 |         feat_m.append(self.layer1)
133 |         feat_m.append(self.layer2)
134 |         feat_m.append(self.layer3)
135 |         return feat_m
136 |
137 |     def get_bn_before_relu(self):
138 |         raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher')
139 |
140 |     def forward(self, x, is_feat=False, preact=False):
141 |         out = F.relu(self.bn1(self.conv1(x)))
142 |         # out = F.max_pool2d(out, 3, stride=2, padding=1)
143 |         f0 = out
144 |         out, f1_pre = self.layer1(out)
145 |         f1 = out
146 |         out, f2_pre = self.layer2(out)
147 |         f2 = out
148 |         out, f3_pre = self.layer3(out)
149 |         f3 = out
150 |         out = F.relu(self.bn2(self.conv2(out)))
151 |         out = F.avg_pool2d(out, 4)
152 |         out = out.view(out.size(0), -1)
153 |         f4 = out
154 |         out = self.linear(out)
155 |         if is_feat:
156 |             if preact:
157 |                 return [f0, f1_pre, f2_pre, f3_pre, f4], out
158 |             else:
159 |                 return [f0, f1, f2, f3, f4], out
160 |         else:
161 |             return out
162 |
163 |
164 | configs = {
165 |     0.2: {
166 |         'out_channels': (40, 80, 160, 512),
167 |         'num_blocks': (3, 3, 3)
168 |     },
169 |
170 |     0.3: {
171 |         'out_channels': (40, 80, 160, 512),
172 |         'num_blocks': (3, 7, 3)
173 |     },
174 |
175 |     0.5: {
176 |         'out_channels': (48, 96, 192, 1024),
177 |         'num_blocks': (3, 7, 3)
178 |     },
179 |
180 |     1: {
181 |         'out_channels': (116, 232, 464, 1024),
182 |         'num_blocks': (3, 7, 3)
183 |     },
184 |     1.5: {
185 |         'out_channels': (176, 352, 704, 1024),
186 |         'num_blocks': (3, 7, 3)
187 |     },
188 |     2: {
189 |         'out_channels': (224, 488, 976, 2048),
190 |         'num_blocks': (3, 7, 3)
191 |     }
192 | }
193 |
194 |
195 | def ShuffleV2(**kwargs):
196 |     model = ShuffleNetV2(net_size=1, **kwargs)
197 |     return model
198 |
199 |
200 | if __name__ == '__main__':
201 |     net = ShuffleV2(num_classes=100)
202 |     x = torch.randn(3, 3, 32, 32)
203 |     import time
204 |     a = time.time()
205 |     feats, logit = net(x, is_feat=True, preact=True)
206 |     b = time.time()
207 |     print(b - a)
208 |     for f in feats:
209 |         print(f.shape, f.min().item())
210 |     print(logit.shape)
211 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4, resnet14x05, resnet20x05, resnet20x0375
2 | from .resnetv2 import ResNet50
3 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
4 | from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
5 | from .mobilenetv2 import mobile_half
6 | from .ShuffleNetv1 import ShuffleV1
7 | from .ShuffleNetv2 import ShuffleV2
8 |
9 | model_dict = {
10 |     'resnet8': resnet8,
11 |     'resnet14': resnet14,
12 |     'resnet20': resnet20,
13 |     'resnet32': resnet32,
14 |     'resnet44': resnet44,
15 |     'resnet56': resnet56,
16 |     'resnet110': resnet110,
17 |     'resnet8x4': resnet8x4,
18 |     'resnet32x4': resnet32x4,
19 |     'ResNet50': ResNet50,
20 |     'wrn_16_1': wrn_16_1,
21 |     'wrn_16_2': wrn_16_2,
22 |     'wrn_40_1': wrn_40_1,
23 |     'wrn_40_2': wrn_40_2,
24 |     'vgg8': vgg8_bn,
25 |     'vgg11': vgg11_bn,
26 |     'vgg13': vgg13_bn,
27 |     'vgg16': vgg16_bn,
28 |     'vgg19': vgg19_bn,
29 |     'MobileNetV2': mobile_half,
30 |     'ShuffleV1': ShuffleV1,
31 |     'ShuffleV2': ShuffleV2,
32 |     'resnet14x05': resnet14x05,
33 |     'resnet20x05': resnet20x05,
34 |     'resnet20x0375': resnet20x0375,
35 | }
36 |
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 |
5 |
6 | #########################################
7 | # ===== Classifiers ===== #
8 | #########################################
9 |
10 | class LinearClassifier(nn.Module):
11 |
12 |     def __init__(self, dim_in, n_label=10):
13 |         super(LinearClassifier, self).__init__()
14 |
15 |         self.net = nn.Linear(dim_in, n_label)
16 |
17 |     def forward(self, x):
18 |         return self.net(x)
19 |
20 |
21 | class NonLinearClassifier(nn.Module):
22 |
23 |     def __init__(self, dim_in, n_label=10, p=0.1):
24 |         super(NonLinearClassifier, self).__init__()
25 |
26 |         self.net = nn.Sequential(
27 |             nn.Linear(dim_in, 200),
28 |             nn.Dropout(p=p),
29 |             nn.BatchNorm1d(200),
30 |             nn.ReLU(inplace=True),
31 |             nn.Linear(200, n_label),
32 |         )
33 |
34 |     def forward(self, x):
35 |         return self.net(x)
36 |
--------------------------------------------------------------------------------
/models/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | MobileNetV2 implementation used in
3 |
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import math
9 |
10 | __all__ = ['mobilenetv2_T_w', 'mobile_half']
11 |
12 | BN = None
13 |
14 |
15 | def conv_bn(inp, oup, stride):
16 |     return nn.Sequential(
17 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
18 |         nn.BatchNorm2d(oup),
19 |         nn.ReLU(inplace=True)
20 |     )
21 |
22 |
23 | def conv_1x1_bn(inp, oup):
24 |     return nn.Sequential(
25 |         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
26 |         nn.BatchNorm2d(oup),
27 |         nn.ReLU(inplace=True)
28 |     )
29 |
30 |
31 | class InvertedResidual(nn.Module):
32 |     def __init__(self, inp, oup, stride, expand_ratio):
33 |         super(InvertedResidual, self).__init__()
34 |         self.blockname = None
35 |
36 |         self.stride = stride
37 |         assert stride in [1, 2]
38 |
39 |         self.use_res_connect = self.stride == 1 and inp == oup
40 |
41 |         self.conv = nn.Sequential(
42 |             # pw
43 |             nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
44 |             nn.BatchNorm2d(inp * expand_ratio),
45 |             nn.ReLU(inplace=True),
46 |             # dw
47 |             nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False),
48 |             nn.BatchNorm2d(inp * expand_ratio),
49 |             nn.ReLU(inplace=True),
50 |             # pw-linear
51 |             nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
52 |             nn.BatchNorm2d(oup),
53 |         )
54 |         self.names = ['0', '1', '2', '3', '4', '5', '6', '7']
55 |
56 |     def forward(self, x):
57 |         t = x
58 |         if self.use_res_connect:
59 |             return t + self.conv(x)
60 |         else:
61 |             return self.conv(x)
62 |
63 |
64 | class MobileNetV2(nn.Module):
65 |     """mobilenetV2"""
66 |     def __init__(self, T,
67 |                  feature_dim,
68 |                  input_size=32,
69 |                  width_mult=1.,
70 |                  remove_avg=False):
71 |         super(MobileNetV2, self).__init__()
72 |         self.remove_avg = remove_avg
73 |
74 |         # setting of inverted residual blocks
75 |         self.interverted_residual_setting = [
76 |             # t, c, n, s
77 |             [1, 16, 1, 1],
78 |             [T, 24, 2, 1],
79 |             [T, 32, 3, 2],
80 |             [T, 64, 4, 2],
81 |             [T, 96, 3, 1],
82 |             [T, 160, 3, 2],
83 |             [T, 320, 1, 1],
84 |         ]
85 |
86 |         # building first layer
87 |         assert input_size % 32 == 0
88 |         input_channel = int(32 * width_mult)
89 |         self.conv1 = conv_bn(3, input_channel, 2)
90 |
91 |         # building inverted residual blocks
92 |         self.blocks = nn.ModuleList([])
93 |         for t, c, n, s in self.interverted_residual_setting:
94 |             output_channel = int(c * width_mult)
95 |             layers = []
96 |             strides = [s] + [1] * (n - 1)
97 |             for stride in strides:
98 |                 layers.append(
99 |                     InvertedResidual(input_channel, output_channel, stride, t)
100 |                 )
101 |                 input_channel = output_channel
102 |             self.blocks.append(nn.Sequential(*layers))
103 |
104 |         self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280
105 |         self.conv2 = conv_1x1_bn(input_channel, self.last_channel)
106 |
107 |         H = input_size // (32//2)
108 |         self.avgpool = nn.AvgPool2d(H, ceil_mode=True)
109 |
110 |         # building classifier
111 |         #self.classifier = nn.Sequential(
112 |         #    # nn.Dropout(0.5),
113 |         #    nn.Linear(self.last_channel, feature_dim),
114 |         #)
115 |         self.classifier = nn.Linear(self.last_channel, feature_dim)
116 |
117 |         self._initialize_weights()
118 |         print(T, width_mult)
119 |
120 |     def get_bn_before_relu(self):
121 |         bn1 = self.blocks[1][-1].conv[-1]
122 |         bn2 = self.blocks[2][-1].conv[-1]
123 |         bn3 = self.blocks[4][-1].conv[-1]
124 |         bn4 = self.blocks[6][-1].conv[-1]
125 |         return [bn1, bn2, bn3, bn4]
126 |
127 |     def get_feat_modules(self):
128 |         feat_m = nn.ModuleList([])
129 |         feat_m.append(self.conv1)
130 |         feat_m.append(self.blocks)
131 |         return feat_m
132 |
133 |     def forward(self, x, is_feat=False, preact=False):
134 |
135 |         out = self.conv1(x)
136 |         f0 = out
137 |
138 |         out = self.blocks[0](out)
139 |         out = self.blocks[1](out)
140 |         f1 = out
141 |         out = self.blocks[2](out)
142 |         f2 = out
143 |         out = self.blocks[3](out)
144 |         out = self.blocks[4](out)
145 |         f3 = out
146 |         out = self.blocks[5](out)
147 |         out = self.blocks[6](out)
148 |         f4 = out
149 |
150 |         out = self.conv2(out)
151 |
152 |         if not self.remove_avg:
153 |             out = self.avgpool(out)
154 |         out = out.view(out.size(0), -1)
155 |         f5 = out
156 |         out = self.classifier(out)
157 |
158 |         if is_feat:
159 |             return [f0, f1, f2, f3, f4, f5], out
160 |         else:
161 |             return out
162 |
163 | def _initialize_weights(self):
164 | for m in self.modules():
165 | if isinstance(m, nn.Conv2d):
166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
167 | m.weight.data.normal_(0, math.sqrt(2. / n))
168 | if m.bias is not None:
169 | m.bias.data.zero_()
170 | elif isinstance(m, nn.BatchNorm2d):
171 | m.weight.data.fill_(1)
172 | m.bias.data.zero_()
173 | elif isinstance(m, nn.Linear):
174 | n = m.weight.size(1)
175 | m.weight.data.normal_(0, 0.01)
176 | m.bias.data.zero_()
177 |
178 |
179 | def mobilenetv2_T_w(T, W, feature_dim=100):
180 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W)
181 | return model
182 |
183 |
184 | def mobile_half(num_classes):
185 | return mobilenetv2_T_w(6, 0.5, num_classes)
186 |
187 |
188 | if __name__ == '__main__':
189 | x = torch.randn(2, 3, 32, 32)
190 |
191 | net = mobile_half(100)
192 |
193 | feats, logit = net(x, is_feat=True, preact=True)
194 | for f in feats:
195 | print(f.shape, f.min().item())
196 | print(logit.shape)
197 |
198 | for m in net.get_bn_before_relu():
199 | if isinstance(m, nn.BatchNorm2d):
200 | print('pass')
201 | else:
202 | print('warning')
203 |
204 |
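The stage table in `MobileNetV2.__init__` can be hard to read at a glance, so here is a small torch-free sketch that expands it into one tuple per inverted residual block. The helper name `block_plan` and its `first_channel` parameter are illustrative and not part of the repository; the `(t, c, n, s)` rows are copied from the listing above.

```python
def block_plan(T, width_mult, first_channel=32):
    """Return (in_channels, out_channels, stride, expansion) for every block.

    Hypothetical helper: mirrors the loop in MobileNetV2.__init__ without
    building any torch modules.
    """
    # Same (t, c, n, s) rows as the setting table in the listing above.
    settings = [
        (1, 16, 1, 1),
        (T, 24, 2, 1),
        (T, 32, 3, 2),
        (T, 64, 4, 2),
        (T, 96, 3, 1),
        (T, 160, 3, 2),
        (T, 320, 1, 1),
    ]
    in_ch = int(first_channel * width_mult)
    plan = []
    for t, c, n, s in settings:
        out_ch = int(c * width_mult)
        # Only the first block of a stage uses the stage stride; the rest use 1.
        for stride in [s] + [1] * (n - 1):
            plan.append((in_ch, out_ch, stride, t))
            in_ch = out_ch
    return plan


# mobile_half corresponds to T=6, width_mult=0.5.
plan = block_plan(T=6, width_mult=0.5)
num_downsampling_blocks = sum(1 for _, _, stride, _ in plan if stride == 2)
print(len(plan), num_downsampling_blocks)
print(plan[0], plan[-1])
```

Three blocks use stride 2; together with the stride-2 `conv1` this gives an overall downsampling factor of 16, which matches `H = input_size // (32//2)` used for the average pool.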
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | '''Resnet for cifar dataset.
4 | Ported from
5 | https://github.com/facebook/fb.resnet.torch
6 | and
7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
8 | (c) YANG, Wei
9 | '''
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import math
13 |
14 |
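15 | __all__ = ['ResNet', 'resnet8', 'resnet14', 'resnet20', 'resnet14x05', 'resnet20x05', 'resnet20x0375', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnet8x4', 'resnet32x4']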
16 |
17 |
18 | def conv3x3(in_planes, out_planes, stride=1):
19 | """3x3 convolution with padding"""
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
21 | padding=1, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
28 | super(BasicBlock, self).__init__()
29 | self.is_last = is_last
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | preact = out
53 | out = F.relu(out)
54 | if self.is_last:
55 | return out, preact
56 | else:
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False):
64 | super(Bottleneck, self).__init__()
65 | self.is_last = is_last
66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
67 | self.bn1 = nn.BatchNorm2d(planes)
68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
69 | padding=1, bias=False)
70 | self.bn2 = nn.BatchNorm2d(planes)
71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
72 | self.bn3 = nn.BatchNorm2d(planes * 4)
73 | self.relu = nn.ReLU(inplace=True)
74 | self.downsample = downsample
75 | self.stride = stride
76 |
77 | def forward(self, x):
78 | residual = x
79 |
80 | out = self.conv1(x)
81 | out = self.bn1(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv2(out)
85 | out = self.bn2(out)
86 | out = self.relu(out)
87 |
88 | out = self.conv3(out)
89 | out = self.bn3(out)
90 |
91 | if self.downsample is not None:
92 | residual = self.downsample(x)
93 |
94 | out += residual
95 | preact = out
96 | out = F.relu(out)
97 | if self.is_last:
98 | return out, preact
99 | else:
100 | return out
101 |
102 |
103 | class ResNet(nn.Module):
104 |
105 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10):
106 | super(ResNet, self).__init__()
107 | # Model type specifies number of layers for CIFAR-10 model
108 | if block_name.lower() == 'basicblock':
109 | assert (depth - 2) % 6 == 0, 'When using basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202'
110 | n = (depth - 2) // 6
111 | block = BasicBlock
112 | elif block_name.lower() == 'bottleneck':
113 | assert (depth - 2) % 9 == 0, 'When using bottleneck, depth should be 9n+2, e.g. 20, 29, 47, 56, 110, 1199'
114 | n = (depth - 2) // 9
115 | block = Bottleneck
116 | else:
117 | raise ValueError('block_name should be BasicBlock or Bottleneck')
118 |
119 | self.inplanes = num_filters[0]
120 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1,
121 | bias=False)
122 | self.bn1 = nn.BatchNorm2d(num_filters[0])
123 | self.relu = nn.ReLU(inplace=True)
124 | self.layer1 = self._make_layer(block, num_filters[1], n)
125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2)
126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2)
127 | self.avgpool = nn.AvgPool2d(8)
128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def _make_layer(self, block, planes, blocks, stride=1):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | nn.Conv2d(self.inplanes, planes * block.expansion,
142 | kernel_size=1, stride=stride, bias=False),
143 | nn.BatchNorm2d(planes * block.expansion),
144 | )
145 |
146 | layers = list([])
147 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1)))
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1)))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def get_feat_modules(self):
155 | feat_m = nn.ModuleList([])
156 | feat_m.append(self.conv1)
157 | feat_m.append(self.bn1)
158 | feat_m.append(self.relu)
159 | feat_m.append(self.layer1)
160 | feat_m.append(self.layer2)
161 | feat_m.append(self.layer3)
162 | return feat_m
163 |
164 | def get_bn_before_relu(self):
165 | if isinstance(self.layer1[0], Bottleneck):
166 | bn1 = self.layer1[-1].bn3
167 | bn2 = self.layer2[-1].bn3
168 | bn3 = self.layer3[-1].bn3
169 | elif isinstance(self.layer1[0], BasicBlock):
170 | bn1 = self.layer1[-1].bn2
171 | bn2 = self.layer2[-1].bn2
172 | bn3 = self.layer3[-1].bn2
173 | else:
174 | raise NotImplementedError('ResNet unknown block error !!!')
175 |
176 | return [bn1, bn2, bn3]
177 |
178 | def forward(self, x, is_feat=False, preact=False):
179 | x = self.conv1(x)
180 | x = self.bn1(x)
181 | x = self.relu(x) # 32x32
182 | f0 = x
183 |
184 | x, f1_pre = self.layer1(x) # 32x32
185 | f1 = x
186 | x, f2_pre = self.layer2(x) # 16x16
187 | f2 = x
188 | x, f3_pre = self.layer3(x) # 8x8
189 | f3 = x
190 |
191 | x = self.avgpool(x)
192 | x = x.view(x.size(0), -1)
193 | f4 = x
194 | x = self.fc(x)
195 |
196 | if is_feat:
197 | if preact:
198 | return [f0, f1_pre, f2_pre, f3_pre, f4], x
199 | else:
200 | return [f0, f1, f2, f3, f4], x
201 | else:
202 | return x
203 |
204 |
205 | def resnet8(**kwargs):
206 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs)
207 |
208 |
209 | def resnet14(**kwargs):
210 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs)
211 |
212 | def resnet20(**kwargs):
213 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs)
214 |
215 |
216 | def resnet14x05(**kwargs):
217 | return ResNet(14, [8, 8, 16, 32], 'basicblock', **kwargs)
218 |
219 | def resnet20x05(**kwargs):
220 | return ResNet(20, [8, 8, 16, 32], 'basicblock', **kwargs)
221 |
222 | def resnet20x0375(**kwargs):
223 | return ResNet(20, [6, 6, 12, 24], 'basicblock', **kwargs)
224 |
225 |
226 |
227 | def resnet32(**kwargs):
228 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs)
229 |
230 |
231 | def resnet44(**kwargs):
232 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs)
233 |
234 |
235 | def resnet56(**kwargs):
236 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs)
237 |
238 |
239 | def resnet110(**kwargs):
240 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs)
241 |
242 |
243 | def resnet8x4(**kwargs):
244 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs)
245 |
246 |
247 | def resnet32x4(**kwargs):
248 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs)
249 |
250 |
251 | if __name__ == '__main__':
252 | import torch
253 |
254 | x = torch.randn(2, 3, 32, 32)
255 | net = resnet8x4(num_classes=20)
256 | feats, logit = net(x, is_feat=True, preact=True)
257 |
258 | for f in feats:
259 | print(f.shape, f.min().item())
260 | print(logit.shape)
261 |
262 | for m in net.get_bn_before_relu():
263 | if isinstance(m, nn.BatchNorm2d):
264 | print('pass')
265 | else:
266 | print('warning')
267 |
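The depth arithmetic in `ResNet.__init__` follows from counting convolutions: each of the three stages holds `n` blocks, a `BasicBlock` has two convolutions and a `Bottleneck` three, and the remaining two layers are the stem convolution and the final linear layer. A minimal sketch of that rule, using a hypothetical helper `blocks_per_stage` and raising `ValueError` where the listing uses `assert`:

```python
def blocks_per_stage(depth, block_name="basicblock"):
    """Number of residual blocks in each of the three stages.

    Hypothetical helper; it reproduces the 6n+2 / 9n+2 rule from the
    listing above and is not part of the repository.
    """
    convs_per_block = {"basicblock": 2, "bottleneck": 3}
    name = block_name.lower()
    if name not in convs_per_block:
        raise ValueError("block_name should be BasicBlock or Bottleneck")
    # Three stages, each contributing n blocks of convs_per_block layers.
    step = 3 * convs_per_block[name]
    if (depth - 2) % step != 0:
        raise ValueError("depth should be {}n+2, got {}".format(step, depth))
    return (depth - 2) // step


for depth in (8, 20, 56, 110):
    print(depth, blocks_per_stage(depth))
```

For example, `resnet8x4` builds one block per stage, so every block in it is also the last block of its stage and returns the pre-activation alongside the output.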
--------------------------------------------------------------------------------
/models/resnetv2.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, in_planes, planes, stride=1, is_last=False):
16 | super(BasicBlock, self).__init__()
17 | self.is_last = is_last
18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
19 | self.bn1 = nn.BatchNorm2d(planes)
20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.shortcut = nn.Sequential()
24 | if stride != 1 or in_planes != self.expansion * planes:
25 | self.shortcut = nn.Sequential(
26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
27 | nn.BatchNorm2d(self.expansion * planes)
28 | )
29 |
30 | def forward(self, x):
31 | out = F.relu(self.bn1(self.conv1(x)))
32 | out = self.bn2(self.conv2(out))
33 | out += self.shortcut(x)
34 | preact = out
35 | out = F.relu(out)
36 | if self.is_last:
37 | return out, preact
38 | else:
39 | return out
40 |
41 |
42 | class Bottleneck(nn.Module):
43 | expansion = 4
44 |
45 | def __init__(self, in_planes, planes, stride=1, is_last=False):
46 | super(Bottleneck, self).__init__()
47 | self.is_last = is_last
48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
51 | self.bn2 = nn.BatchNorm2d(planes)
52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False)
53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
54 |
55 | self.shortcut = nn.Sequential()
56 | if stride != 1 or in_planes != self.expansion * planes:
57 | self.shortcut = nn.Sequential(
58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
59 | nn.BatchNorm2d(self.expansion * planes)
60 | )
61 |
62 | def forward(self, x):
63 | out = F.relu(self.bn1(self.conv1(x)))
64 | out = F.relu(self.bn2(self.conv2(out)))
65 | out = self.bn3(self.conv3(out))
66 | out += self.shortcut(x)
67 | preact = out
68 | out = F.relu(out)
69 | if self.is_last:
70 | return out, preact
71 | else:
72 | return out
73 |
74 |
75 | class ResNet(nn.Module):
76 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False):
77 | super(ResNet, self).__init__()
78 | self.in_planes = 64
79 |
80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
81 | self.bn1 = nn.BatchNorm2d(64)
82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
87 | self.linear = nn.Linear(512 * block.expansion, num_classes)
88 |
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
93 | nn.init.constant_(m.weight, 1)
94 | nn.init.constant_(m.bias, 0)
95 |
96 | # Zero-initialize the last BN in each residual branch,
97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
99 | if zero_init_residual:
100 | for m in self.modules():
101 | if isinstance(m, Bottleneck):
102 | nn.init.constant_(m.bn3.weight, 0)
103 | elif isinstance(m, BasicBlock):
104 | nn.init.constant_(m.bn2.weight, 0)
105 |
106 | def get_feat_modules(self):
107 | feat_m = nn.ModuleList([])
108 | feat_m.append(self.conv1)
109 | feat_m.append(self.bn1)
110 | feat_m.append(self.layer1)
111 | feat_m.append(self.layer2)
112 | feat_m.append(self.layer3)
113 | feat_m.append(self.layer4)
114 | return feat_m
115 |
116 | def get_bn_before_relu(self):
117 | if isinstance(self.layer1[0], Bottleneck):
118 | bn1 = self.layer1[-1].bn3
119 | bn2 = self.layer2[-1].bn3
120 | bn3 = self.layer3[-1].bn3
121 | bn4 = self.layer4[-1].bn3
122 | elif isinstance(self.layer1[0], BasicBlock):
123 | bn1 = self.layer1[-1].bn2
124 | bn2 = self.layer2[-1].bn2
125 | bn3 = self.layer3[-1].bn2
126 | bn4 = self.layer4[-1].bn2
127 | else:
128 | raise NotImplementedError('ResNet unknown block error !!!')
129 |
130 | return [bn1, bn2, bn3, bn4]
131 |
132 | def _make_layer(self, block, planes, num_blocks, stride):
133 | strides = [stride] + [1] * (num_blocks - 1)
134 | layers = []
135 | for i in range(num_blocks):
136 | stride = strides[i]
137 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1))
138 | self.in_planes = planes * block.expansion
139 | return nn.Sequential(*layers)
140 |
141 | def forward(self, x, is_feat=False, preact=False):
142 | out = F.relu(self.bn1(self.conv1(x)))
143 | f0 = out
144 | out, f1_pre = self.layer1(out)
145 | f1 = out
146 | out, f2_pre = self.layer2(out)
147 | f2 = out
148 | out, f3_pre = self.layer3(out)
149 | f3 = out
150 | out, f4_pre = self.layer4(out)
151 | f4 = out
152 | out = self.avgpool(out)
153 | out = out.view(out.size(0), -1)
154 | f5 = out
155 | out = self.linear(out)
156 | if is_feat:
157 | if preact:
158 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out
159 | else:
160 | return [f0, f1, f2, f3, f4, f5], out
161 | else:
162 | return out
163 |
164 |
165 | def ResNet18(**kwargs):
166 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
167 |
168 |
169 | def ResNet34(**kwargs):
170 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
171 |
172 |
173 | def ResNet50(**kwargs):
174 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
175 |
176 |
177 | def ResNet101(**kwargs):
178 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
179 |
180 |
181 | def ResNet152(**kwargs):
182 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
183 |
184 |
185 | if __name__ == '__main__':
186 | net = ResNet18(num_classes=100)
187 | x = torch.randn(2, 3, 32, 32)
188 | feats, logit = net(x, is_feat=True, preact=True)
189 |
190 | for f in feats:
191 | print(f.shape, f.min().item())
192 | print(logit.shape)
193 |
194 | for m in net.get_bn_before_relu():
195 | if isinstance(m, nn.BatchNorm2d):
196 | print('pass')
197 | else:
198 | print('warning')
199 |
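To see which feature shapes `forward(..., is_feat=True)` returns without running torch, the sizes can be derived from the standard convolution output formula. The sketch below assumes square inputs and uses hypothetical helpers `conv_out` and `feature_shapes`; widths and strides are taken from `ResNet.__init__` in the listing above, and `expansion` is 1 for `BasicBlock` and 4 for `Bottleneck`.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_shapes(input_size=32, expansion=1,
                   widths=(64, 128, 256, 512), strides=(1, 2, 2, 2)):
    """Per-sample shapes of [f0, f1, f2, f3, f4, f5] (hypothetical helper)."""
    shapes = []
    # Stem: 3x3 convolution, stride 1, padding 1, 64 output channels.
    size = conv_out(input_size, 3, 1, 1)
    shapes.append((64, size, size))
    for planes, stride in zip(widths, strides):
        # The strided 3x3 convolution in the first block of a stage sets the
        # resolution; later blocks in the stage keep it unchanged.
        size = conv_out(size, 3, stride, 1)
        shapes.append((planes * expansion, size, size))
    # f5 is the flattened output of the adaptive average pool.
    shapes.append((widths[-1] * expansion,))
    return shapes


for shape in feature_shapes(32, expansion=1):
    print(shape)
```

With `expansion=4` the same call describes the `Bottleneck` variants, whose pooled feature `f5` has 2048 entries.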
--------------------------------------------------------------------------------
/models/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch.nn as nn
4 | import math
5 |
6 |
7 | class Paraphraser(nn.Module):
8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer"""
9 | def __init__(self, t_shape, k=0.5, use_bn=False):
10 | super(Paraphraser, self).__init__()
11 | in_channel = t_shape[1]
12 | out_channel = int(t_shape[1] * k)
13 | self.encoder = nn.Sequential(
14 | nn.Conv2d(in_channel, in_channel, 3, 1, 1),
15 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
16 | nn.LeakyReLU(0.1, inplace=True),
17 | nn.Conv2d(in_channel, out_channel, 3, 1, 1),
18 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
19 | nn.LeakyReLU(0.1, inplace=True),
20 | nn.Conv2d(out_channel, out_channel, 3, 1, 1),
21 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
22 | nn.LeakyReLU(0.1, inplace=True),
23 | )
24 | self.decoder = nn.Sequential(
25 | nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1),
26 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
27 | nn.LeakyReLU(0.1, inplace=True),
28 | nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1),
29 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
30 | nn.LeakyReLU(0.1, inplace=True),
31 | nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1),
32 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
33 | nn.LeakyReLU(0.1, inplace=True),
34 | )
35 |
36 | def forward(self, f_s, is_factor=False):
37 | factor = self.encoder(f_s)
38 | if is_factor:
39 | return factor
40 | rec = self.decoder(factor)
41 | return factor, rec
42 |
43 |
44 | class Translator(nn.Module):
45 | def __init__(self, s_shape, t_shape, k=0.5, use_bn=True):
46 | super(Translator, self).__init__()
47 | in_channel = s_shape[1]
48 | out_channel = int(t_shape[1] * k)
49 | self.encoder = nn.Sequential(
50 | nn.Conv2d(in_channel, in_channel, 3, 1, 1),
51 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(),
52 | nn.LeakyReLU(0.1, inplace=True),
53 | nn.Conv2d(in_channel, out_channel, 3, 1, 1),
54 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
55 | nn.LeakyReLU(0.1, inplace=True),
56 | nn.Conv2d(out_channel, out_channel, 3, 1, 1),
57 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(),
58 | nn.LeakyReLU(0.1, inplace=True),
59 | )
60 |
61 | def forward(self, f_s):
62 | return self.encoder(f_s)
63 |
64 |
65 | class Connector(nn.Module):
66 | """Connector for Knowledge Transfer via Distillation of Activation Boundaries Formed by Hidden Neurons"""
67 | def __init__(self, s_shapes, t_shapes):
68 | super(Connector, self).__init__()
69 | self.s_shapes = s_shapes
70 | self.t_shapes = t_shapes
71 |
72 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes))
73 |
74 | @staticmethod
75 | def _make_conenctors(s_shapes, t_shapes):
76 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
77 | connectors = []
78 | for s, t in zip(s_shapes, t_shapes):
79 | if s[1] == t[1] and s[2] == t[2]:
80 | connectors.append(nn.Sequential())
81 | else:
82 | connectors.append(ConvReg(s, t, use_relu=False))
83 | return connectors
84 |
85 | def forward(self, g_s):
86 | out = []
87 | for i in range(len(g_s)):
88 | out.append(self.connectors[i](g_s[i]))
89 |
90 | return out
91 |
92 |
93 | class ConnectorV2(nn.Module):
94 | """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)"""
95 | def __init__(self, s_shapes, t_shapes):
96 | super(ConnectorV2, self).__init__()
97 | self.s_shapes = s_shapes
98 | self.t_shapes = t_shapes
99 |
100 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes))
101 |
102 | def _make_conenctors(self, s_shapes, t_shapes):
103 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list'
104 | t_channels = [t[1] for t in t_shapes]
105 | s_channels = [s[1] for s in s_shapes]
106 | connectors = nn.ModuleList([self._build_feature_connector(t, s)
107 | for t, s in zip(t_channels, s_channels)])
108 | return connectors
109 |
110 | @staticmethod
111 | def _build_feature_connector(t_channel, s_channel):
112 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False),
113 | nn.BatchNorm2d(t_channel)]
114 | for m in C:
115 | if isinstance(m, nn.Conv2d):
116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
117 | m.weight.data.normal_(0, math.sqrt(2. / n))
118 | elif isinstance(m, nn.BatchNorm2d):
119 | m.weight.data.fill_(1)
120 | m.bias.data.zero_()
121 | return nn.Sequential(*C)
122 |
123 | def forward(self, g_s):
124 | out = []
125 | for i in range(len(g_s)):
126 | out.append(self.connectors[i](g_s[i]))
127 |
128 | return out
129 |
130 |
131 | class ConvReg(nn.Module):
132 | """Convolutional regression for FitNet"""
133 | def __init__(self, s_shape, t_shape, use_relu=True):
134 | super(ConvReg, self).__init__()
135 | self.use_relu = use_relu
136 | s_N, s_C, s_H, s_W = s_shape
137 | t_N, t_C, t_H, t_W = t_shape
138 | if s_H == 2 * t_H:
139 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
140 | elif s_H * 2 == t_H:
141 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
142 | elif s_H >= t_H:
143 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
144 | else:
145 | raise NotImplementedError('student size {}, teacher size {}'.format(s_H, t_H))
146 | self.bn = nn.BatchNorm2d(t_C)
147 | self.relu = nn.ReLU(inplace=True)
148 |
149 | def forward(self, x):
150 | x = self.conv(x)
151 | if self.use_relu:
152 | return self.relu(self.bn(x))
153 | else:
154 | return self.bn(x)
155 |
156 |
157 | class Regress(nn.Module):
158 | """Simple Linear Regression for hints"""
159 | def __init__(self, dim_in=1024, dim_out=1024):
160 | super(Regress, self).__init__()
161 | self.linear = nn.Linear(dim_in, dim_out)
162 | self.relu = nn.ReLU(inplace=True)
163 |
164 | def forward(self, x):
165 | x = x.view(x.shape[0], -1)
166 | x = self.linear(x)
167 | x = self.relu(x)
168 | return x
169 |
170 |
171 | class Embed(nn.Module):
172 | """Embedding module"""
173 | def __init__(self, dim_in=1024, dim_out=128):
174 | super(Embed, self).__init__()
175 | self.linear = nn.Linear(dim_in, dim_out)
176 | self.l2norm = Normalize(2)
177 |
178 | def forward(self, x):
179 | x = x.view(x.shape[0], -1)
180 | x = self.linear(x)
181 | x = self.l2norm(x)
182 | return x
183 |
184 |
185 | class LinearEmbed(nn.Module):
186 | """Linear Embedding"""
187 | def __init__(self, dim_in=1024, dim_out=128):
188 | super(LinearEmbed, self).__init__()
189 | self.linear = nn.Linear(dim_in, dim_out)
190 |
191 | def forward(self, x):
192 | x = x.view(x.shape[0], -1)
193 | x = self.linear(x)
194 | return x
195 |
196 |
197 | class MLPEmbed(nn.Module):
198 | """non-linear embed by MLP"""
199 | def __init__(self, dim_in=1024, dim_out=128):
200 | super(MLPEmbed, self).__init__()
201 | self.linear1 = nn.Linear(dim_in, 2 * dim_out)
202 | self.relu = nn.ReLU(inplace=True)
203 | self.linear2 = nn.Linear(2 * dim_out, dim_out)
204 | self.l2norm = Normalize(2)
205 |
206 | def forward(self, x):
207 | x = x.view(x.shape[0], -1)
208 | x = self.relu(self.linear1(x))
209 | x = self.l2norm(self.linear2(x))
210 | return x
211 |
212 |
213 | class Normalize(nn.Module):
214 | """normalization layer"""
215 | def __init__(self, power=2):
216 | super(Normalize, self).__init__()
217 | self.power = power
218 |
219 | def forward(self, x):
220 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
221 | out = x.div(norm)
222 | return out
223 |
224 |
225 | class Flatten(nn.Module):
226 | """flatten module"""
227 | def __init__(self):
228 | super(Flatten, self).__init__()
229 |
230 | def forward(self, feat):
231 | return feat.view(feat.size(0), -1)
232 |
233 |
234 | class PoolEmbed(nn.Module):
235 | """pool and embed"""
236 | def __init__(self, layer=0, dim_out=128, pool_type='avg'):
237 | super().__init__()
238 | if layer == 0:
239 | pool_size = 8
240 | nChannels = 16
241 | elif layer == 1:
242 | pool_size = 8
243 | nChannels = 16
244 | elif layer == 2:
245 | pool_size = 6
246 | nChannels = 32
247 | elif layer == 3:
248 | pool_size = 4
249 | nChannels = 64
250 | elif layer == 4:
251 | pool_size = 1
252 | nChannels = 64
253 | else:
254 | raise NotImplementedError('layer not supported: {}'.format(layer))
255 |
256 | self.embed = nn.Sequential()
257 | if layer <= 3:
258 | if pool_type == 'max':
259 | self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size)))
260 | elif pool_type == 'avg':
261 | self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size)))
262 |
263 | self.embed.add_module('Flatten', Flatten())
264 | self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out))
265 | self.embed.add_module('Normalize', Normalize(2))
266 |
267 | def forward(self, x):
268 | return self.embed(x)
269 |
270 |
271 | if __name__ == '__main__':
272 | import torch
273 |
274 | g_s = [
275 | torch.randn(2, 16, 16, 16),
276 | torch.randn(2, 32, 8, 8),
277 | torch.randn(2, 64, 4, 4),
278 | ]
279 | g_t = [
280 | torch.randn(2, 32, 16, 16),
281 | torch.randn(2, 64, 8, 8),
282 | torch.randn(2, 128, 4, 4),
283 | ]
284 | s_shapes = [s.shape for s in g_s]
285 | t_shapes = [t.shape for t in g_t]
286 |
287 | net = ConnectorV2(s_shapes, t_shapes)
288 | out = net(g_s)
289 | for f in out:
290 | print(f.shape)
291 |
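`ConvReg` picks its convolution from the ratio of student to teacher spatial size. The following torch-free sketch restates that branch logic and checks that each choice maps the student size onto the teacher size. The helpers `convreg_plan` and `output_hw` are hypothetical names introduced here; the kernel, stride and padding values are the ones in the listing above, and the output formulas are the usual ones for convolution and transposed convolution with dilation 1.

```python
def convreg_plan(s_hw, t_hw):
    """Return (kind, kernel, stride, padding) as chosen in ConvReg.__init__."""
    s_H, s_W = s_hw
    t_H, t_W = t_hw
    if s_H == 2 * t_H:
        # Student map is twice as large: strided 3x3 convolution halves it.
        return "conv", (3, 3), 2, 1
    if s_H * 2 == t_H:
        # Student map is half as large: transposed convolution doubles it.
        return "conv_transpose", (4, 4), 2, 1
    if s_H >= t_H:
        # Otherwise shrink with an unpadded kernel of exactly the size gap.
        return "conv", (1 + s_H - t_H, 1 + s_W - t_W), 1, 0
    raise NotImplementedError(
        "student size {}, teacher size {}".format(s_H, t_H))


def output_hw(s_hw, plan):
    """Spatial output size of the planned layer (dilation 1, no output padding)."""
    kind, kernel, stride, padding = plan
    if kind == "conv":
        return tuple((s + 2 * padding - k) // stride + 1
                     for s, k in zip(s_hw, kernel))
    return tuple((s - 1) * stride - 2 * padding + k
                 for s, k in zip(s_hw, kernel))


for s_hw, t_hw in [((16, 16), (8, 8)), ((8, 8), (16, 16)), ((8, 8), (6, 6))]:
    plan = convreg_plan(s_hw, t_hw)
    print(s_hw, t_hw, plan, output_hw(s_hw, plan))
```

A student map that is smaller than the teacher map by anything other than a factor of two falls through to the final branch and is rejected.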
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | '''VGG for CIFAR10. FC layers are removed.
2 | (c) YANG, Wei
3 | '''
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 |
8 |
9 | __all__ = [
10 | 'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
11 | 'vgg19_bn', 'vgg19',
12 | ]
13 |
14 |
15 | model_urls = {
16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
20 | }
21 |
22 |
23 | class VGG(nn.Module):
24 |
25 | def __init__(self, cfg, batch_norm=False, num_classes=1000):
26 | super(VGG, self).__init__()
27 | self.block0 = self._make_layers(cfg[0], batch_norm, 3)
28 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1])
29 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1])
30 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1])
31 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1])
32 |
33 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
34 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
35 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
36 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
37 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1))
38 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
39 |
40 | self.classifier = nn.Linear(512, num_classes)
41 | self._initialize_weights()
42 |
43 | def get_feat_modules(self):
44 | feat_m = nn.ModuleList([])
45 | feat_m.append(self.block0)
46 | feat_m.append(self.pool0)
47 | feat_m.append(self.block1)
48 | feat_m.append(self.pool1)
49 | feat_m.append(self.block2)
50 | feat_m.append(self.pool2)
51 | feat_m.append(self.block3)
52 | feat_m.append(self.pool3)
53 | feat_m.append(self.block4)
54 | feat_m.append(self.pool4)
55 | return feat_m
56 |
57 | def get_bn_before_relu(self):
58 | bn1 = self.block1[-1]
59 | bn2 = self.block2[-1]
60 | bn3 = self.block3[-1]
61 | bn4 = self.block4[-1]
62 | return [bn1, bn2, bn3, bn4]
63 |
64 | def forward(self, x, is_feat=False, preact=False):
65 | h = x.shape[2]
66 | x = F.relu(self.block0(x))
67 | f0 = x
68 | x = self.pool0(x)
69 | x = self.block1(x)
70 | f1_pre = x
71 | x = F.relu(x)
72 | f1 = x
73 | x = self.pool1(x)
74 | x = self.block2(x)
75 | f2_pre = x
76 | x = F.relu(x)
77 | f2 = x
78 | x = self.pool2(x)
79 | x = self.block3(x)
80 | f3_pre = x
81 | x = F.relu(x)
82 | f3 = x
83 | if h == 64:
84 | x = self.pool3(x)
85 | x = self.block4(x)
86 | f4_pre = x
87 | x = F.relu(x)
88 | f4 = x
89 | x = self.pool4(x)
90 | x = x.view(x.size(0), -1)
91 | f5 = x
92 | x = self.classifier(x)
93 |
94 | if is_feat:
95 | if preact:
96 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x
97 | else:
98 | return [f0, f1, f2, f3, f4, f5], x
99 | else:
100 | return x
101 |
102 | @staticmethod
103 | def _make_layers(cfg, batch_norm=False, in_channels=3):
104 | layers = []
105 | for v in cfg:
106 | if v == 'M':
107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
108 | else:
109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
110 | if batch_norm:
111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
112 | else:
113 | layers += [conv2d, nn.ReLU(inplace=True)]
114 | in_channels = v
115 | layers = layers[:-1]
116 | return nn.Sequential(*layers)
117 |
118 | def _initialize_weights(self):
119 | for m in self.modules():
120 | if isinstance(m, nn.Conv2d):
121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
122 | m.weight.data.normal_(0, math.sqrt(2. / n))
123 | if m.bias is not None:
124 | m.bias.data.zero_()
125 | elif isinstance(m, nn.BatchNorm2d):
126 | m.weight.data.fill_(1)
127 | m.bias.data.zero_()
128 | elif isinstance(m, nn.Linear):
129 | n = m.weight.size(1)
130 | m.weight.data.normal_(0, 0.01)
131 | m.bias.data.zero_()
132 |
133 |
134 | cfg = {
135 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]],
136 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]],
137 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]],
138 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]],
139 | 'S': [[64], [128], [256], [512], [512]],
140 | }
141 |
142 |
143 | def vgg8(**kwargs):
144 | """VGG 8-layer model (configuration "S")
145 | Args:
146 | **kwargs: forwarded to VGG (e.g. num_classes); no pretrained weights are loaded
147 | """
148 | model = VGG(cfg['S'], **kwargs)
149 | return model
150 |
151 |
152 | def vgg8_bn(**kwargs):
153 | """VGG 8-layer model (configuration "S") with batch normalization
154 | Args:
155 | **kwargs: forwarded to VGG (e.g. num_classes); no pretrained weights are loaded
156 | """
157 | model = VGG(cfg['S'], batch_norm=True, **kwargs)
158 | return model
159 |
160 |
161 | def vgg11(**kwargs):
162 | """VGG 11-layer model (configuration "A")
163 | Args:
163 | **kwargs: forwarded to VGG (e.g. num_classes); no pretrained weights are loaded
165 | """
166 | model = VGG(cfg['A'], **kwargs)
167 | return model
168 |
169 |
170 | def vgg11_bn(**kwargs):
171 | """VGG 11-layer model (configuration "A") with batch normalization"""
172 | model = VGG(cfg['A'], batch_norm=True, **kwargs)
173 | return model
174 |
175 |
176 | def vgg13(**kwargs):
177 | """VGG 13-layer model (configuration "B")
178 | Args:
178 | **kwargs: forwarded to VGG (e.g. num_classes); no pretrained weights are loaded
180 | """
181 | model = VGG(cfg['B'], **kwargs)
182 | return model
183 |
184 |
185 | def vgg13_bn(**kwargs):
186 | """VGG 13-layer model (configuration "B") with batch normalization"""
187 | model = VGG(cfg['B'], batch_norm=True, **kwargs)
188 | return model
189 |
190 |
191 | def vgg16(**kwargs):
192 | """VGG 16-layer model (configuration "D")
193 | Args:
194 | pretrained (bool): If True, returns a model pre-trained on ImageNet
195 | """
196 | model = VGG(cfg['D'], **kwargs)
197 | return model
198 |
199 |
200 | def vgg16_bn(**kwargs):
201 | """VGG 16-layer model (configuration "D") with batch normalization"""
202 | model = VGG(cfg['D'], batch_norm=True, **kwargs)
203 | return model
204 |
205 |
206 | def vgg19(**kwargs):
207 | """VGG 19-layer model (configuration "E")
208 | Args:
209 | pretrained (bool): If True, returns a model pre-trained on ImageNet
210 | """
211 | model = VGG(cfg['E'], **kwargs)
212 | return model
213 |
214 |
215 | def vgg19_bn(**kwargs):
216 | """VGG 19-layer model (configuration 'E') with batch normalization"""
217 | model = VGG(cfg['E'], batch_norm=True, **kwargs)
218 | return model
219 |
220 |
221 | if __name__ == '__main__':
222 |     import torch
223 |
224 |     x = torch.randn(2, 3, 32, 32)
225 |     net = vgg19_bn(num_classes=100)
226 |     feats, logit = net(x, is_feat=True, preact=True)
227 |
228 |     for f in feats:
229 |         print(f.shape, f.min().item())
230 |     print(logit.shape)
231 |
232 |     for m in net.get_bn_before_relu():
233 |         if isinstance(m, nn.BatchNorm2d):
234 |             print('pass')
235 |         else:
236 |             print('warning')
237 |
--------------------------------------------------------------------------------
/models/wrn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | """
7 | Original Author: Wei Yang
8 | """
9 |
10 | __all__ = ['wrn']
11 |
12 |
13 | class BasicBlock(nn.Module):
14 |     def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
15 |         super(BasicBlock, self).__init__()
16 |         self.bn1 = nn.BatchNorm2d(in_planes)
17 |         self.relu1 = nn.ReLU(inplace=True)
18 |         self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
19 |                                padding=1, bias=False)
20 |         self.bn2 = nn.BatchNorm2d(out_planes)
21 |         self.relu2 = nn.ReLU(inplace=True)
22 |         self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
23 |                                padding=1, bias=False)
24 |         self.droprate = dropRate
25 |         self.equalInOut = (in_planes == out_planes)
26 |         self.convShortcut = None if self.equalInOut else nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
27 |                                                                    padding=0, bias=False)
28 |
29 |     def forward(self, x):
30 |         if not self.equalInOut:
31 |             x = self.relu1(self.bn1(x))
32 |         else:
33 |             out = self.relu1(self.bn1(x))
34 |         out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
35 |         if self.droprate > 0:
36 |             out = F.dropout(out, p=self.droprate, training=self.training)
37 |         out = self.conv2(out)
38 |         return torch.add(x if self.equalInOut else self.convShortcut(x), out)
39 |
40 |
41 | class NetworkBlock(nn.Module):
42 |     def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
43 |         super(NetworkBlock, self).__init__()
44 |         self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
45 |
46 |     def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
47 |         layers = []
48 |         for i in range(nb_layers):
49 |             layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropRate))
50 |         return nn.Sequential(*layers)
51 |
52 |     def forward(self, x):
53 |         return self.layer(x)
54 |
55 |
56 | class WideResNet(nn.Module):
57 |     def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
58 |         super(WideResNet, self).__init__()
59 |         nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
60 |         assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
61 |         n = (depth - 4) // 6
62 |         block = BasicBlock
63 |         # 1st conv before any network block
64 |         self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
65 |                                padding=1, bias=False)
66 |         # 1st block
67 |         self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
68 |         # 2nd block
69 |         self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
70 |         # 3rd block
71 |         self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
72 |         # global average pooling and classifier
73 |         self.bn1 = nn.BatchNorm2d(nChannels[3])
74 |         self.relu = nn.ReLU(inplace=True)
75 |         self.fc = nn.Linear(nChannels[3], num_classes)
76 |         self.nChannels = nChannels[3]
77 |
78 |         for m in self.modules():
79 |             if isinstance(m, nn.Conv2d):
80 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
81 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
82 |             elif isinstance(m, nn.BatchNorm2d):
83 |                 m.weight.data.fill_(1)
84 |                 m.bias.data.zero_()
85 |             elif isinstance(m, nn.Linear):
86 |                 m.bias.data.zero_()
87 |
88 |     def get_feat_modules(self):
89 |         feat_m = nn.ModuleList([])
90 |         feat_m.append(self.conv1)
91 |         feat_m.append(self.block1)
92 |         feat_m.append(self.block2)
93 |         feat_m.append(self.block3)
94 |         return feat_m
95 |
96 |     def get_bn_before_relu(self):
97 |         bn1 = self.block2.layer[0].bn1
98 |         bn2 = self.block3.layer[0].bn1
99 |         bn3 = self.bn1
100 |
101 |         return [bn1, bn2, bn3]
102 |
103 |     def forward(self, x, is_feat=False, preact=False):
104 |         out = self.conv1(x)
105 |         f0 = out
106 |         out = self.block1(out)
107 |         f1 = out
108 |         out = self.block2(out)
109 |         f2 = out
110 |         out = self.block3(out)
111 |         f3 = out
112 |         out = self.relu(self.bn1(out))
113 |         out = F.avg_pool2d(out, 8)
114 |         out = out.view(-1, self.nChannels)
115 |         f4 = out
116 |         out = self.fc(out)
117 |         if is_feat:
118 |             if preact:
119 |                 f1 = self.block2.layer[0].bn1(f1)
120 |                 f2 = self.block3.layer[0].bn1(f2)
121 |                 f3 = self.bn1(f3)
122 |             return [f0, f1, f2, f3, f4], out
123 |         else:
124 |             return out
125 |
126 |
127 | def wrn(**kwargs):
128 |     """
129 |     Constructs a Wide Residual Network.
130 |     """
131 |     model = WideResNet(**kwargs)
132 |     return model
133 |
134 |
135 | def wrn_40_2(**kwargs):
136 |     model = WideResNet(depth=40, widen_factor=2, **kwargs)
137 |     return model
138 |
139 |
140 | def wrn_40_1(**kwargs):
141 |     model = WideResNet(depth=40, widen_factor=1, **kwargs)
142 |     return model
143 |
144 |
145 | def wrn_16_2(**kwargs):
146 |     model = WideResNet(depth=16, widen_factor=2, **kwargs)
147 |     return model
148 |
149 |
150 | def wrn_16_1(**kwargs):
151 |     model = WideResNet(depth=16, widen_factor=1, **kwargs)
152 |     return model
153 |
154 |
155 | if __name__ == '__main__':
156 |     import torch
157 |
158 |     x = torch.randn(2, 3, 32, 32)
159 |     net = wrn_40_2(num_classes=100)
160 |     feats, logit = net(x, is_feat=True, preact=True)
161 |
162 |     for f in feats:
163 |         print(f.shape, f.min().item())
164 |     print(logit.shape)
165 |
166 |     for m in net.get_bn_before_relu():
167 |         if isinstance(m, nn.BatchNorm2d):
168 |             print('pass')
169 |         else:
170 |             print('warning')
171 |
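The `depth = 6n + 4` constraint and the `nChannels` list in `WideResNet.__init__` fix the layout of every `wrn_*` variant. A minimal sketch of that arithmetic without PyTorch; the helper name `wrn_layout` is hypothetical and not part of the repo:

```python
def wrn_layout(depth, widen_factor):
    """Return (blocks_per_stage, channel_widths) the way WideResNet.__init__ derives them."""
    if (depth - 4) % 6 != 0:
        raise ValueError('depth should be 6n+4')
    n = (depth - 4) // 6  # number of BasicBlocks in each of the 3 stages
    widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    return n, widths


# The four variants defined in models/wrn.py.
for depth, k in [(40, 2), (40, 1), (16, 2), (16, 1)]:
    print('wrn_{}_{}'.format(depth, k), wrn_layout(depth, k))
```

The last width is also the input dimension of the final `fc` layer, which is what `wrapper.py` reads to size its projection head.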
--------------------------------------------------------------------------------
/student.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.utils.data import DataLoader
13 |
14 | import torchvision.transforms as transforms
15 | from tensorboardX import SummaryWriter
16 |
17 | from utils import AverageMeter, accuracy
18 | from wrapper import wrapper
19 | from cifar import CIFAR100
20 |
21 | from models import model_dict
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 | parser = argparse.ArgumentParser(description='train SSKD student network.')
26 | parser.add_argument('--epoch', type=int, default=240)
27 | parser.add_argument('--t-epoch', type=int, default=60)
28 | parser.add_argument('--batch-size', type=int, default=64)
29 |
30 | parser.add_argument('--lr', type=float, default=0.05)
31 | parser.add_argument('--t-lr', type=float, default=0.05)
32 | parser.add_argument('--momentum', type=float, default=0.9)
33 | parser.add_argument('--weight-decay', type=float, default=5e-4)
34 | parser.add_argument('--gamma', type=float, default=0.1)
35 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210])
36 | parser.add_argument('--t-milestones', type=int, nargs='+', default=[30,45])
37 |
38 | parser.add_argument('--save-interval', type=int, default=40)
39 | parser.add_argument('--ce-weight', type=float, default=0.1) # cross-entropy
40 | parser.add_argument('--kd-weight', type=float, default=0.9) # knowledge distillation
41 | parser.add_argument('--tf-weight', type=float, default=2.7) # transformation
42 | parser.add_argument('--ss-weight', type=float, default=10.0) # self-supervision
43 |
44 | parser.add_argument('--kd-T', type=float, default=4.0) # temperature in KD
45 | parser.add_argument('--tf-T', type=float, default=4.0) # temperature in LT
46 | parser.add_argument('--ss-T', type=float, default=0.5) # temperature in SS
47 |
48 | parser.add_argument('--ratio-tf', type=float, default=1.0) # keep how many wrong predictions of LT
49 | parser.add_argument('--ratio-ss', type=float, default=0.75) # keep how many wrong predictions of SS
50 | parser.add_argument('--s-arch', type=str) # student architecture
51 | parser.add_argument('--t-path', type=str) # teacher checkpoint path
52 |
53 | parser.add_argument('--seed', type=int, default=0)
54 | parser.add_argument('--gpu-id', type=int, default=0)
55 |
56 | args = parser.parse_args()
57 | torch.manual_seed(args.seed)
58 | torch.cuda.manual_seed(args.seed)
59 | np.random.seed(args.seed)
60 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
61 |
62 |
63 | t_name = osp.abspath(args.t_path).split('/')[-1]
64 | t_arch = '_'.join(t_name.split('_')[1:-1])
65 | exp_name = 'sskd_student_{}_weight{}+{}+{}+{}_T{}+{}+{}_ratio{}+{}_seed{}_{}'.format(
66 |     args.s_arch,
67 |     args.ce_weight, args.kd_weight, args.tf_weight, args.ss_weight,
68 |     args.kd_T, args.tf_T, args.ss_T,
69 |     args.ratio_tf, args.ratio_ss,
70 |     args.seed, t_name)
71 | exp_path = './experiments/{}'.format(exp_name)
72 | os.makedirs(exp_path, exist_ok=True)
73 |
74 | transform_train = transforms.Compose([
75 |     transforms.RandomCrop(32, padding=4),
76 |     transforms.ToTensor(),
77 |     transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]),
78 | ])
79 | transform_test = transforms.Compose([
80 |     transforms.ToTensor(),
81 |     transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]),
82 | ])
83 |
84 | trainset = CIFAR100('./data', train=True, transform=transform_train)
85 | valset = CIFAR100('./data', train=False, transform=transform_test)
86 |
87 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
88 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
89 |
90 | ckpt_path = osp.join(args.t_path, 'ckpt/best.pth')
91 | t_model = model_dict[t_arch](num_classes=100).cuda()
92 | state_dict = torch.load(ckpt_path)['state_dict']
93 | t_model.load_state_dict(state_dict)
94 | t_model = wrapper(module=t_model).cuda()
95 |
96 | t_optimizer = optim.SGD([{'params':t_model.backbone.parameters(), 'lr':0.0},
97 |                          {'params':t_model.proj_head.parameters(), 'lr':args.t_lr}],
98 |                         momentum=args.momentum, weight_decay=args.weight_decay)
99 | t_model.eval()
100 | t_scheduler = MultiStepLR(t_optimizer, milestones=args.t_milestones, gamma=args.gamma)
101 |
102 | logger = SummaryWriter(osp.join(exp_path, 'events'))
103 |
104 | acc_record = AverageMeter()
105 | loss_record = AverageMeter()
106 | start = time.time()
107 | for x, target in val_loader:
108 |
109 |     x = x[:,0,:,:,:].cuda()  # first of the 4 views per image (the 'normal' one, cf. nor_index below)
110 |     target = target.cuda()
111 |     with torch.no_grad():
112 |         output, _, feat = t_model(x)
113 |         loss = F.cross_entropy(output, target)
114 |
115 |     batch_acc = accuracy(output, target, topk=(1,))[0]
116 |     acc_record.update(batch_acc.item(), x.size(0))
117 |     loss_record.update(loss.item(), x.size(0))
118 |
119 | run_time = time.time() - start
120 | info = 'teacher cls_acc:{:.2f}\n'.format(acc_record.avg)
121 | print(info)
122 |
123 | # train ssp_head
124 | for epoch in range(args.t_epoch):
125 |
126 |     t_model.eval()
127 |     loss_record = AverageMeter()
128 |     acc_record = AverageMeter()
129 |
130 |     start = time.time()
131 |     for x, _ in train_loader:
132 |
133 |         t_optimizer.zero_grad()
134 |
135 |         x = x.cuda()
136 |         c,h,w = x.size()[-3:]
137 |         x = x.view(-1, c, h, w)
138 |
139 |         _, rep, feat = t_model(x, bb_grad=False)
140 |         batch = int(x.size(0) / 4)
141 |         nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
142 |         aug_index = (torch.arange(4*batch) % 4 != 0).cuda()
143 |
144 |         nor_rep = rep[nor_index]
145 |         aug_rep = rep[aug_index]
146 |         nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)
147 |         aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch)
148 |         simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)
149 |         target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
150 |         loss = F.cross_entropy(simi, target)
151 |
152 |         loss.backward()
153 |         t_optimizer.step()
154 |
155 |         batch_acc = accuracy(simi, target, topk=(1,))[0]
156 |         loss_record.update(loss.item(), 3*batch)
157 |         acc_record.update(batch_acc.item(), 3*batch)
158 |
159 |     logger.add_scalar('train/teacher_ssp_loss', loss_record.avg, epoch+1)
160 |     logger.add_scalar('train/teacher_ssp_acc', acc_record.avg, epoch+1)
161 |
162 |     run_time = time.time() - start
163 |     info = 'teacher_train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t ssp_loss:{:.3f}\t ssp_acc:{:.2f}\t'.format(
164 |         epoch+1, args.t_epoch, run_time, loss_record.avg, acc_record.avg)
165 |     print(info)
166 |
167 |     t_model.eval()
168 |     acc_record = AverageMeter()
169 |     loss_record = AverageMeter()
170 |     start = time.time()
171 |     for x, _ in val_loader:
172 |
173 |         x = x.cuda()
174 |         c,h,w = x.size()[-3:]
175 |         x = x.view(-1, c, h, w)
176 |
177 |         with torch.no_grad():
178 |             _, rep, feat = t_model(x)
179 |         batch = int(x.size(0) / 4)
180 |         nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
181 |         aug_index = (torch.arange(4*batch) % 4 != 0).cuda()
182 |
183 |         nor_rep = rep[nor_index]
184 |         aug_rep = rep[aug_index]
185 |         nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)
186 |         aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch)
187 |         simi = F.cosine_similarity(aug_rep, nor_rep, dim=1)
188 |         target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
189 |         loss = F.cross_entropy(simi, target)
190 |
191 |         batch_acc = accuracy(simi, target, topk=(1,))[0]
192 |         acc_record.update(batch_acc.item(),3*batch)
193 |         loss_record.update(loss.item(), 3*batch)
194 |
195 |     run_time = time.time() - start
196 |     logger.add_scalar('val/teacher_ssp_loss', loss_record.avg, epoch+1)
197 |     logger.add_scalar('val/teacher_ssp_acc', acc_record.avg, epoch+1)
198 |
199 |     info = 'ssp_test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t ssp_loss:{:.3f}\t ssp_acc:{:.2f}\n'.format(
200 |         epoch+1, args.t_epoch, run_time, loss_record.avg, acc_record.avg)
201 |     print(info)
202 |
203 |     t_scheduler.step()
204 |
205 |
206 | name = osp.join(exp_path, 'ckpt/teacher.pth')
207 | os.makedirs(osp.dirname(name), exist_ok=True)
208 | torch.save(t_model.state_dict(), name)
209 |
210 |
211 | s_model = model_dict[args.s_arch](num_classes=100)
212 | s_model = wrapper(module=s_model).cuda()
213 | optimizer = optim.SGD(s_model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
214 | scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
215 |
216 | best_acc = 0
217 | for epoch in range(args.epoch):
218 |
219 |     # train
220 |     s_model.train()
221 |     loss1_record = AverageMeter()
222 |     loss2_record = AverageMeter()
223 |     loss3_record = AverageMeter()
224 |     loss4_record = AverageMeter()
225 |     cls_acc_record = AverageMeter()
226 |     ssp_acc_record = AverageMeter()
227 |
228 |     start = time.time()
229 |     for x, target in train_loader:
230 |
231 |         optimizer.zero_grad()
232 |
233 |         c,h,w = x.size()[-3:]
234 |         x = x.view(-1,c,h,w).cuda()
235 |         target = target.cuda()
236 |
237 |         batch = int(x.size(0) / 4)
238 |         nor_index = (torch.arange(4*batch) % 4 == 0).cuda()
239 |         aug_index = (torch.arange(4*batch) % 4 != 0).cuda()
240 |
241 |         output, s_feat, _ = s_model(x, bb_grad=True)
242 |         log_nor_output = F.log_softmax(output[nor_index] / args.kd_T, dim=1)
243 |         log_aug_output = F.log_softmax(output[aug_index] / args.tf_T, dim=1)
244 |         with torch.no_grad():
245 |             knowledge, t_feat, _ = t_model(x)
246 |             nor_knowledge = F.softmax(knowledge[nor_index] / args.kd_T, dim=1)
247 |             aug_knowledge = F.softmax(knowledge[aug_index] / args.tf_T, dim=1)
248 |
249 |         # error level ranking
250 |         aug_target = target.unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
251 |         rank = torch.argsort(aug_knowledge, dim=1, descending=True)
252 |         rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1)  # groundtruth label's rank
253 |         index = torch.argsort(rank)
254 |         tmp = torch.nonzero(rank, as_tuple=True)[0]
255 |         wrong_num = tmp.numel()
256 |         correct_num = 3*batch - wrong_num
257 |         wrong_keep = int(wrong_num * args.ratio_tf)
258 |         index = index[:correct_num+wrong_keep]
259 |         distill_index_tf = torch.sort(index)[0]
260 |
261 |         s_nor_feat = s_feat[nor_index]
262 |         s_aug_feat = s_feat[aug_index]
263 |         s_nor_feat = s_nor_feat.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)
264 |         s_aug_feat = s_aug_feat.unsqueeze(2).expand(-1,-1,1*batch)
265 |         s_simi = F.cosine_similarity(s_aug_feat, s_nor_feat, dim=1)
266 |
267 |         t_nor_feat = t_feat[nor_index]
268 |         t_aug_feat = t_feat[aug_index]
269 |         t_nor_feat = t_nor_feat.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2)
270 |         t_aug_feat = t_aug_feat.unsqueeze(2).expand(-1,-1,1*batch)
271 |         t_simi = F.cosine_similarity(t_aug_feat, t_nor_feat, dim=1)
272 |
273 |         t_simi = t_simi.detach()
274 |         aug_target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda()
275 |         rank = torch.argsort(t_simi, dim=1, descending=True)
276 |         rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1)  # groundtruth label's rank
277 |         index = torch.argsort(rank)
278 |         tmp = torch.nonzero(rank, as_tuple=True)[0]
279 |         wrong_num = tmp.numel()
280 |         correct_num = 3*batch - wrong_num
281 |         wrong_keep = int(wrong_num * args.ratio_ss)
282 |         index = index[:correct_num+wrong_keep]
283 |         distill_index_ss = torch.sort(index)[0]
284 |
285 |         log_simi = F.log_softmax(s_simi / args.ss_T, dim=1)
286 |         simi_knowledge = F.softmax(t_simi / args.ss_T, dim=1)
287 |
288 |         loss1 = F.cross_entropy(output[nor_index], target)
289 |         loss2 = F.kl_div(log_nor_output, nor_knowledge, reduction='batchmean') * args.kd_T * args.kd_T
290 |         loss3 = F.kl_div(log_aug_output[distill_index_tf], aug_knowledge[distill_index_tf],
291 |                          reduction='batchmean') * args.tf_T * args.tf_T
292 |         loss4 = F.kl_div(log_simi[distill_index_ss], simi_knowledge[distill_index_ss],
293 |                          reduction='batchmean') * args.ss_T * args.ss_T
294 |
295 |         loss = args.ce_weight * loss1 + args.kd_weight * loss2 + args.tf_weight * loss3 + args.ss_weight * loss4
296 |
297 |         loss.backward()
298 |         optimizer.step()
299 |
300 |         cls_batch_acc = accuracy(output[nor_index], target, topk=(1,))[0]
301 |         ssp_batch_acc = accuracy(s_simi, aug_target, topk=(1,))[0]
302 |         loss1_record.update(loss1.item(), batch)
303 |         loss2_record.update(loss2.item(), batch)
304 |         loss3_record.update(loss3.item(), len(distill_index_tf))
305 |         loss4_record.update(loss4.item(), len(distill_index_ss))
306 |         cls_acc_record.update(cls_batch_acc.item(), batch)
307 |         ssp_acc_record.update(ssp_batch_acc.item(), 3*batch)
308 |
309 |     logger.add_scalar('train/ce_loss', loss1_record.avg, epoch+1)
310 |     logger.add_scalar('train/kd_loss', loss2_record.avg, epoch+1)
311 |     logger.add_scalar('train/tf_loss', loss3_record.avg, epoch+1)
312 |     logger.add_scalar('train/ss_loss', loss4_record.avg, epoch+1)
313 |     logger.add_scalar('train/cls_acc', cls_acc_record.avg, epoch+1)
314 |     logger.add_scalar('train/ss_acc', ssp_acc_record.avg, epoch+1)
315 |
316 |     run_time = time.time() - start
317 |     info = 'student_train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t ce_loss:{:.3f}\t kd_loss:{:.3f}\t cls_acc:{:.2f}'.format(
318 |         epoch+1, args.epoch, run_time, loss1_record.avg, loss2_record.avg, cls_acc_record.avg)
319 |     print(info)
320 |
321 |     # cls val
322 |     s_model.eval()
323 |     acc_record = AverageMeter()
324 |     loss_record = AverageMeter()
325 |     start = time.time()
326 |     for x, target in val_loader:
327 |
328 |         x = x[:,0,:,:,:].cuda()
329 |         target = target.cuda()
330 |         with torch.no_grad():
331 |             output, _, feat = s_model(x)
332 |             loss = F.cross_entropy(output, target)
333 |
334 |         batch_acc = accuracy(output, target, topk=(1,))[0]
335 |         acc_record.update(batch_acc.item(), x.size(0))
336 |         loss_record.update(loss.item(), x.size(0))
337 |
338 |     run_time = time.time() - start
339 |     logger.add_scalar('val/ce_loss', loss_record.avg, epoch+1)
340 |     logger.add_scalar('val/cls_acc', acc_record.avg, epoch+1)
341 |
342 |     info = 'student_test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t cls_acc:{:.2f}\n'.format(
343 |         epoch+1, args.epoch, run_time, acc_record.avg)
344 |     print(info)
345 |
346 |     if acc_record.avg > best_acc:
347 |         best_acc = acc_record.avg
348 |         state_dict = dict(epoch=epoch+1, state_dict=s_model.state_dict(), best_acc=best_acc)
349 |         name = osp.join(exp_path, 'ckpt/student_best.pth')
350 |         os.makedirs(osp.dirname(name), exist_ok=True)
351 |         torch.save(state_dict, name)
352 |
353 |     scheduler.step()
354 |
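The "error level ranking" step in `student.py` keeps every sample on which the teacher ranks the ground-truth label first, plus the best-ranked fraction (`--ratio-tf` / `--ratio-ss`) of the samples it gets wrong. A minimal pure-Python sketch of that selection, assuming `rank` already holds the ground-truth label's rank per sample (0 means the teacher is correct); the function name `select_for_distillation` is hypothetical:

```python
def select_for_distillation(rank, ratio):
    """Return sorted sample indices kept for distillation.

    rank[i] is the position of sample i's ground-truth label in the teacher's
    sorted prediction (0 = teacher is correct). All correct samples are kept,
    plus the int(wrong_num * ratio) wrong samples with the smallest rank.
    """
    order = sorted(range(len(rank)), key=lambda i: rank[i])  # argsort by rank
    wrong_num = sum(1 for r in rank if r != 0)
    correct_num = len(rank) - wrong_num
    wrong_keep = int(wrong_num * ratio)
    return sorted(order[:correct_num + wrong_keep])


rank = [0, 3, 1, 0, 7, 2]  # samples 0 and 3 are predicted correctly
print(select_for_distillation(rank, 0.75))  # drops only the worst-ranked sample
print(select_for_distillation(rank, 1.0))   # keeps everything
```

Python's `sorted` is stable, while the script's `torch.argsort` makes no such promise, so which of several equally ranked wrong samples survives the cut may differ between the two.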
355 |
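`student.py` derives the teacher architecture from the directory name passed to `--t-path`: it drops the first and last underscore-separated tokens of the name. A small sketch of that parsing, using a hypothetical helper `teacher_arch_from_path` and `posixpath` so that the `'/'` split behaves the same everywhere:

```python
import posixpath


def teacher_arch_from_path(t_path):
    """Mirror of the t_name / t_arch logic at the top of student.py."""
    t_name = posixpath.abspath(t_path).split('/')[-1]  # abspath also drops a trailing '/'
    return '_'.join(t_name.split('_')[1:-1])           # strip 'teacher' prefix and 'seedN' suffix


print(repr(teacher_arch_from_path('./experiments/teacher_wrn_40_2_seed0/')))
print(repr(teacher_arch_from_path('./experiments/teacher_vgg13_seed0')))
print(repr(teacher_arch_from_path('./experiments/teacher_vgg13')))
```

A directory without a trailing token, such as `teacher_vgg13`, yields an empty architecture string, so the later `model_dict[t_arch]` lookup would presumably fail. Naming directories for downloaded checkpoints in the `teacher_<arch>_seed<N>` form that `teacher.py` produces avoids this.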
--------------------------------------------------------------------------------
/teacher.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import argparse
4 | import time
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.optim.lr_scheduler import MultiStepLR
12 | from torch.utils.data import DataLoader
13 |
14 | import torchvision.transforms as transforms
15 | from torchvision.datasets import CIFAR100
16 | from tensorboardX import SummaryWriter
17 |
18 | from utils import AverageMeter, accuracy
19 | from models import model_dict
20 |
21 | torch.backends.cudnn.benchmark = True
22 |
23 | parser = argparse.ArgumentParser(description='train teacher network.')
24 | parser.add_argument('--epoch', type=int, default=240)
25 | parser.add_argument('--batch-size', type=int, default=64)
26 |
27 | parser.add_argument('--lr', type=float, default=0.05)
28 | parser.add_argument('--momentum', type=float, default=0.9)
29 | parser.add_argument('--weight-decay', type=float, default=5e-4)
30 | parser.add_argument('--gamma', type=float, default=0.1)
31 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210])
32 |
33 | parser.add_argument('--save-interval', type=int, default=40)
34 | parser.add_argument('--arch', type=str)
35 | parser.add_argument('--seed', type=int, default=0)
36 | parser.add_argument('--gpu-id', type=int, default=0)
37 |
38 | args = parser.parse_args()
39 | torch.manual_seed(args.seed)
40 | torch.cuda.manual_seed(args.seed)
41 | np.random.seed(args.seed)
42 |
43 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
44 |
45 | exp_name = 'teacher_{}_seed{}'.format(args.arch, args.seed)
46 | exp_path = './experiments/{}'.format(exp_name)
47 | os.makedirs(exp_path, exist_ok=True)
48 |
49 | transform_train = transforms.Compose([
50 |     transforms.RandomCrop(32, padding=4),
51 |     transforms.RandomHorizontalFlip(),
52 |     transforms.ToTensor(),
53 |     transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]),
54 | ])
55 | transform_test = transforms.Compose([
56 |     transforms.ToTensor(),
57 |     transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]),
58 | ])
59 |
60 | trainset = CIFAR100('./data', train=True, transform=transform_train, download=True)
61 | valset = CIFAR100('./data', train=False, transform=transform_test, download=True)
62 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False)
63 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False)
64 |
65 | model = model_dict[args.arch](num_classes=100).cuda()
66 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
67 | scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
68 |
69 | logger = SummaryWriter(osp.join(exp_path, 'events'))
70 | best_acc = -1
71 | for epoch in range(args.epoch):
72 |
73 |     model.train()
74 |     loss_record = AverageMeter()
75 |     acc_record = AverageMeter()
76 |
77 |     start = time.time()
78 |     for x, target in train_loader:
79 |
80 |         optimizer.zero_grad()
81 |         x = x.cuda()
82 |         target = target.cuda()
83 |
84 |         output = model(x)
85 |         loss = F.cross_entropy(output, target)
86 |
87 |         loss.backward()
88 |         optimizer.step()
89 |
90 |         batch_acc = accuracy(output, target, topk=(1,))[0]
91 |         loss_record.update(loss.item(), x.size(0))
92 |         acc_record.update(batch_acc.item(), x.size(0))
93 |
94 |     logger.add_scalar('train/cls_loss', loss_record.avg, epoch+1)
95 |     logger.add_scalar('train/cls_acc', acc_record.avg, epoch+1)
96 |
97 |     run_time = time.time() - start
98 |
99 |     info = 'train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t cls_loss:{:.3f}\t cls_acc:{:.2f}\t'.format(
100 |         epoch+1, args.epoch, run_time, loss_record.avg, acc_record.avg)
101 |     print(info)
102 |
103 |     model.eval()
104 |     acc_record = AverageMeter()
105 |     loss_record = AverageMeter()
106 |     start = time.time()
107 |     for x, target in val_loader:
108 |
109 |         x = x.cuda()
110 |         target = target.cuda()
111 |         with torch.no_grad():
112 |             output = model(x)
113 |             loss = F.cross_entropy(output, target)
114 |
115 |         batch_acc = accuracy(output, target, topk=(1,))[0]
116 |         loss_record.update(loss.item(), x.size(0))
117 |         acc_record.update(batch_acc.item(), x.size(0))
118 |
119 |     run_time = time.time() - start
120 |
121 |     logger.add_scalar('val/cls_loss', loss_record.avg, epoch+1)
122 |     logger.add_scalar('val/cls_acc', acc_record.avg, epoch+1)
123 |
124 |     info = 'test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t cls_loss:{:.3f}\t cls_acc:{:.2f}\n'.format(
125 |         epoch+1, args.epoch, run_time, loss_record.avg, acc_record.avg)
126 |     print(info)
127 |
128 |     scheduler.step()
129 |
130 |     # save checkpoint
131 |     if (epoch+1) in args.milestones or epoch+1==args.epoch or (epoch+1)%args.save_interval==0:
132 |         state_dict = dict(epoch=epoch+1, state_dict=model.state_dict(), acc=acc_record.avg)
133 |         name = osp.join(exp_path, 'ckpt/{:03d}.pth'.format(epoch+1))
134 |         os.makedirs(osp.dirname(name), exist_ok=True)
135 |         torch.save(state_dict, name)
136 |
137 |     # save best
138 |     if acc_record.avg > best_acc:
139 |         state_dict = dict(epoch=epoch+1, state_dict=model.state_dict(), acc=acc_record.avg)
140 |         name = osp.join(exp_path, 'ckpt/best.pth')
141 |         os.makedirs(osp.dirname(name), exist_ok=True)
142 |         torch.save(state_dict, name)
143 |         best_acc = acc_record.avg
144 |
145 | print('best_acc: {:.2f}'.format(best_acc))
146 |
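Both training scripts step a `MultiStepLR` scheduler once at the end of each epoch, so the learning rate used in (0-indexed) epoch `e` is the base rate multiplied by `gamma` once for every milestone that is `<= e`. A minimal sketch of that schedule without PyTorch; the helper `multistep_lr` is hypothetical, and the defaults below are the ones declared by `teacher.py`'s argument parser:

```python
from bisect import bisect_right


def multistep_lr(base_lr, milestones, gamma, epoch):
    """Learning rate in effect during 0-indexed `epoch` under a MultiStepLR-style decay."""
    passed = bisect_right(sorted(milestones), epoch)  # milestones already reached
    return base_lr * gamma ** passed


# Defaults from teacher.py: --lr 0.05 --gamma 0.1 --milestones 150 180 210
for epoch in (0, 149, 150, 180, 210, 239):
    print(epoch, multistep_lr(0.05, [150, 180, 210], 0.1, epoch))
```

With these defaults the rate stays at 0.05 for the first 150 epochs and is divided by ten at each of the three milestones.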
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 |
5 | import torch
6 | from torch.nn import init
7 |
8 | class AverageMeter(object):
9 |     """Computes and stores the average and current value"""
10 |     def __init__(self):
11 |         self.reset()
12 |
13 |     def reset(self):
14 |         self.count = 0
15 |         self.sum = 0.0
16 |         self.val = 0.0
17 |         self.avg = 0.0
18 |
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 |
25 | def accuracy(output, target, topk=(1,)):
26 |     """Computes the precision@k for the specified values of k"""
27 |     maxk = max(topk)
28 |     batch_size = target.size(0)
29 |
30 |     _, pred = output.topk(maxk, 1, True, True)
31 |     pred = pred.t()
32 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
33 |
34 |     res = []
35 |     for k in topk:
36 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape: correct[:k] can be non-contiguous when maxk > 1
37 |         res.append(correct_k.mul_(100.0 / batch_size))
38 |     return res
39 |
40 | def norm(x):
41 |
42 |     n = np.linalg.norm(x)
43 |     return x / n
44 |
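The training loops pass a sample count as the second argument to `AverageMeter.update`, so the epoch-level `avg` is weighted by batch size rather than being a plain mean over batches. A self-contained sketch, with the class repeated from `utils.py` and made-up accuracy values, showing the difference when the last batch is smaller:

```python
class AverageMeter(object):
    """Computes and stores the average and current value (copied from utils.py)."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.count = 0
        self.sum = 0.0
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
meter.update(80.0, 64)  # illustrative: full batch of 64 samples at 80% accuracy
meter.update(50.0, 16)  # illustrative: final partial batch of 16 samples at 50%
print(meter.avg)        # weighted by sample count, not the unweighted mean 65.0
```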
--------------------------------------------------------------------------------
/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class wrapper(nn.Module):
6 |
7 |     def __init__(self, module):
8 |
9 |         super(wrapper, self).__init__()
10 |
11 |         self.backbone = module
12 |         feat_dim = list(module.children())[-1].in_features  # assumes the backbone's last child is its nn.Linear classifier
13 |         self.proj_head = nn.Sequential(
14 |             nn.Linear(feat_dim, feat_dim),
15 |             nn.ReLU(inplace=True),
16 |             nn.Linear(feat_dim, feat_dim)
17 |         )
18 |
19 |     def forward(self, x, bb_grad=True):
20 |
21 |         feats, out = self.backbone(x, is_feat=True)
22 |         feat = feats[-1].view(feats[-1].size(0), -1)
23 |         if not bb_grad:
24 |             feat = feat.detach()  # stop gradients from the projection head reaching the backbone
25 |
26 |         return out, self.proj_head(feat), feat
27 |
28 |
--------------------------------------------------------------------------------