├── README.md
├── cifar.py
├── command.sh
├── frm.png
├── models
│   ├── ShuffleNetv1.py
│   ├── ShuffleNetv2.py
│   ├── __init__.py
│   ├── classifier.py
│   ├── mobilenetv2.py
│   ├── resnet.py
│   ├── resnetv2.py
│   ├── util.py
│   ├── vgg.py
│   └── wrn.py
├── student.py
├── teacher.py
├── utils.py
└── wrapper.py
/README.md:
--------------------------------------------------------------------------------
1 | # SSKD
2 | This repo is the implementation of the paper [Knowledge Distillation Meets Self-Supervision](https://arxiv.org/abs/2006.07114) (ECCV 2020).
3 | 
4 | 
5 | 
6 | ## Prerequisite
7 | This repo is tested with Ubuntu 16.04.5, Python 3.7, PyTorch 1.5.0, CUDA 10.2.
8 | Make sure to install PyTorch, torchvision, tensorboardX and NumPy before using this repo.
9 | 
10 | ## Running
11 | 
12 | ### Teacher Training
13 | An example of teacher training is:
14 | ```
15 | python teacher.py --arch wrn_40_2 --lr 0.05 --gpu-id 0
16 | ```
17 | where you can specify the architecture via the `--arch` flag.
18 | 
19 | You can also download all the pre-trained teacher models [here](https://drive.google.com/drive/folders/1vJ0VdeFRd9a50ObbBD8SslBtmqmj8p8r?usp=sharing).
20 | If you want to run `student.py` directly, you have to re-organise the directory. For instance, when you download *vgg13.pth*, make a directory for it, say *teacher_vgg13*, then make a new directory *ckpt* inside *teacher_vgg13*. Move *vgg13.pth* into *teacher_vgg13/ckpt* and rename it *best.pth*, as sketched below. If you want a simpler way to use a pre-trained model, you can edit the code in `student.py` (line 90).
21 | 
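A minimal shell sketch of that layout, assuming *vgg13.pth* has been downloaded into the repo root (the directory name *teacher_vgg13* is just the example from above, and `--t-path` should then point at it):
```
mkdir -p teacher_vgg13/ckpt
mv vgg13.pth teacher_vgg13/ckpt/best.pth
```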
22 | ### Student Training
23 | An example of student training is:
24 | ```
25 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_16_2 --lr 0.05 --gpu-id 0
26 | ```
27 | The meanings of the flags are:
28 | > `--t-path`: teacher's checkpoint path. The checkpoint whose name contains the keyword 'best' is found automatically.
29 | 
30 | > `--s-arch`: student's architecture.
31 | 
32 | All the commands can be found in `command.sh`.
33 | 
34 | ## Results (Top-1 Acc) on CIFAR100
35 | 
36 | ### Similar-Architecture
37 | 
38 | | Teacher<br>Student | wrn40-2<br>wrn16-2 | wrn40-2<br>wrn40-1 | resnet56<br>resnet20 | resnet32x4<br>resnet8x4 | vgg13<br>vgg8 |
39 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:|:-----------:|
40 | | Teacher<br>Student | 76.46<br>73.64 | 76.46<br>72.24 | 73.44<br>69.63 | 79.63<br>72.51 | 75.38<br>70.68 |
41 | | KD | 74.92 | 73.54 | 70.66 | 73.33 | 72.98 |
42 | | FitNet | 75.75 | 74.12 | 71.60 | 74.31 | 73.54 |
43 | | AT | 75.28 | 74.45 | **71.78** | 74.26 | 73.62 |
44 | | SP | 75.34 | 73.15 | 71.48 | 74.74 | 73.44 |
45 | | VID | 74.79 | 74.20 | 71.71 | 74.82 | 73.96 |
46 | | RKD | 75.40 | 73.87 | 71.48 | 74.47 | 73.72 |
47 | | PKT | 76.01 | 74.40 | 71.44 | 74.17 | 73.37 |
48 | | AB | 68.89 | 75.06 | 71.49 | 74.45 | 74.27 |
49 | | FT | 75.15 | 74.37 | 71.52 | 75.02 | 73.42 |
50 | | CRD | **76.04** | 75.52 | 71.68 | 75.90 | 74.06 |
51 | | **SSKD** | **76.04** | **76.13** | 71.49 | **76.20** | **75.33** |
52 | 
53 | ### Cross-Architecture
54 | 
55 | | Teacher<br>Student | vgg13<br>MobileNetV2 | ResNet50<br>MobileNetV2 | ResNet50<br>vgg8 | resnet32x4<br>ShuffleV1 | resnet32x4<br>ShuffleV2 | wrn40-2<br>ShuffleV1 |
56 | |:---------------:|:-----------------:|:-----------------:|:-----------------:|:--------------------:|:-----------:|:-------------:|
57 | | Teacher<br>Student | 75.38<br>65.79 | 79.10<br>65.79 | 79.10<br>70.68 | 79.63<br>70.77 | 79.63<br>73.12 | 76.46<br>
70.77 | 58 | | KD | 67.37 | 67.35| 73.81| 74.07| 74.45| 74.83| 59 | | FitNet |68.58 | 68.54 | 73.84 | 74.82 | 75.11 | 75.55 | 60 | | AT | 69.34 | 69.28 | 73.45 | 74.76 | 75.30 | 75.61 | 61 | | SP | 66.89 | 68.99 | 73.86 | 73.80 | 75.15 | 75.56 | 62 | | VID | 66.91 | 68.88 | 73.75 | 74.28 | 75.78 | 75.36 | 63 | | RKD | 68.50 | 68.46 | 73.73 | 74.20 | 75.74 | 75.45 | 64 | | PKT | 67.89 | 68.44 | 73.53 | 74.06 | 75.18 | 75.51 | 65 | | AB | 68.86 | 69.32 | 74.20 | 76.24 | 75.66 | 76.58 | 66 | | FT | 69.19 | 69.01 | 73.58 | 74.31 | 74.95 | 75.18 | 67 | | CRD | 68.49 | 70.32 | 74.42 | 75.46 | 75.72 | 75.96 | 68 | | **SSKD** | **71.53** | **72.57** | **75.76** | **78.44** | **78.61** | **77.40** | 69 | 70 | ## Citation 71 | If you find this repo useful for your research, please consider citing the paper 72 | ``` 73 | @inproceedings{xu2020knowledge, 74 | title={Knowledge Distillation Meets Self-Supervision}, 75 | author={Xu, Guodong and Liu, Ziwei and Li, Xiaoxiao and Loy, Chen Change}, 76 | booktitle={European Conference on Computer Vision (ECCV)}, 77 | year={2020}, 78 | } 79 | ``` 80 | ## Acknowledgement 81 | The implementation of `models` is borrowed from [CRD](https://github.com/HobbitLong/RepDistiller) 82 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from PIL import Image 3 | import os 4 | import os.path 5 | import numpy as np 6 | import sys 7 | 8 | import pickle 9 | import torch 10 | import torch.utils.data as data 11 | 12 | from itertools import permutations 13 | 14 | class VisionDataset(data.Dataset): 15 | _repr_indent = 4 16 | 17 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 18 | if isinstance(root, torch._six.string_classes): 19 | root = os.path.expanduser(root) 20 | self.root = root 21 | 22 | has_transforms = transforms is not None 23 | has_separate_transform = transform is not None or target_transform is not None 24 | if has_transforms and has_separate_transform: 25 | raise ValueError("Only transforms or transform/target_transform can " 26 | "be passed as argument") 27 | 28 | # for backwards-compatibility 29 | self.transform = transform 30 | self.target_transform = target_transform 31 | 32 | if has_separate_transform: 33 | transforms = StandardTransform(transform, target_transform) 34 | self.transforms = transforms 35 | 36 | def __getitem__(self, index): 37 | raise NotImplementedError 38 | 39 | def __len__(self): 40 | raise NotImplementedError 41 | 42 | def __repr__(self): 43 | head = "Dataset " + self.__class__.__name__ 44 | body = ["Number of datapoints: {}".format(self.__len__())] 45 | if self.root is not None: 46 | body.append("Root location: {}".format(self.root)) 47 | body += self.extra_repr().splitlines() 48 | if self.transforms is not None: 49 | body += [repr(self.transforms)] 50 | lines = [head] + [" " * self._repr_indent + line for line in body] 51 | return '\n'.join(lines) 52 | 53 | def _format_transform_repr(self, transform, head): 54 | lines = transform.__repr__().splitlines() 55 | return (["{}{}".format(head, lines[0])] + 56 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 57 | 58 | def extra_repr(self): 59 | return "" 60 | 61 | class CIFAR10(VisionDataset): 62 | base_folder = 'cifar-10-batches-py' 63 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 64 | filename = "cifar-10-python.tar.gz" 65 | tgz_md5 = 
'c58f30108f718f92721af3b95e74349a' 66 | train_list = [ 67 | ['data_batch_1', 'c99cafc152244af753f735de768cd75f'], 68 | ['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'], 69 | ['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'], 70 | ['data_batch_4', '634d18415352ddfa80567beed471001a'], 71 | ['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'], 72 | ] 73 | 74 | test_list = [ 75 | ['test_batch', '40351d587109b95175f43aff81a1287e'], 76 | ] 77 | meta = { 78 | 'filename': 'batches.meta', 79 | 'key': 'label_names', 80 | 'md5': '5ff9c542aee3614f3951f8cda6e48888', 81 | } 82 | 83 | def __init__(self, root, train=True, 84 | transform=None, download=False): 85 | 86 | super(CIFAR10, self).__init__(root) 87 | self.transform = transform 88 | 89 | self.train = train # training set or test set 90 | 91 | if download: 92 | raise ValueError('cannot download.') 93 | exit() 94 | #self.download() 95 | 96 | #if not self._check_integrity(): 97 | # raise RuntimeError('Dataset not found or corrupted.' + 98 | # ' You can use download=True to download it') 99 | 100 | if self.train: 101 | downloaded_list = self.train_list 102 | else: 103 | downloaded_list = self.test_list 104 | 105 | self.data = [] 106 | self.targets = [] 107 | 108 | # now load the picked numpy arrays 109 | for file_name, checksum in downloaded_list: 110 | file_path = os.path.join(self.root, self.base_folder, file_name) 111 | with open(file_path, 'rb') as f: 112 | if sys.version_info[0] == 2: 113 | entry = pickle.load(f) 114 | else: 115 | entry = pickle.load(f, encoding='latin1') 116 | self.data.append(entry['data']) 117 | if 'labels' in entry: 118 | self.targets.extend(entry['labels']) 119 | else: 120 | self.targets.extend(entry['fine_labels']) 121 | 122 | self.data = np.vstack(self.data).reshape(-1, 3, 32, 32) 123 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 124 | 125 | self._load_meta() 126 | 127 | def _load_meta(self): 128 | path = os.path.join(self.root, self.base_folder, self.meta['filename']) 129 | #if not check_integrity(path, self.meta['md5']): 130 | # raise RuntimeError('Dataset metadata file not found or corrupted.' 
+ 131 | # ' You can use download=True to download it') 132 | with open(path, 'rb') as infile: 133 | if sys.version_info[0] == 2: 134 | data = pickle.load(infile) 135 | else: 136 | data = pickle.load(infile, encoding='latin1') 137 | self.classes = data[self.meta['key']] 138 | self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} 139 | 140 | def __getitem__(self, index): 141 | 142 | img, target = self.data[index], self.targets[index] 143 | if self.train: 144 | if np.random.rand() < 0.5: 145 | img = img[:,::-1,:] 146 | 147 | img0 = np.rot90(img, 0).copy() 148 | img0 = Image.fromarray(img0) 149 | img0 = self.transform(img0) 150 | 151 | img1 = np.rot90(img, 1).copy() 152 | img1 = Image.fromarray(img1) 153 | img1 = self.transform(img1) 154 | 155 | img2 = np.rot90(img, 2).copy() 156 | img2 = Image.fromarray(img2) 157 | img2 = self.transform(img2) 158 | 159 | img3 = np.rot90(img, 3).copy() 160 | img3 = Image.fromarray(img3) 161 | img3 = self.transform(img3) 162 | 163 | img = torch.stack([img0,img1,img2,img3]) 164 | 165 | return img, target 166 | 167 | 168 | def __len__(self): 169 | return len(self.data) 170 | 171 | def _check_integrity(self): 172 | root = self.root 173 | for fentry in (self.train_list + self.test_list): 174 | filename, md5 = fentry[0], fentry[1] 175 | fpath = os.path.join(root, self.base_folder, filename) 176 | if not check_integrity(fpath, md5): 177 | return False 178 | return True 179 | 180 | def download(self): 181 | import tarfile 182 | 183 | if self._check_integrity(): 184 | print('Files already downloaded and verified') 185 | return 186 | 187 | download_url(self.url, self.root, self.filename, self.tgz_md5) 188 | 189 | # extract file 190 | with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: 191 | tar.extractall(path=self.root) 192 | 193 | def extra_repr(self): 194 | return "Split: {}".format("Train" if self.train is True else "Test") 195 | 196 | 197 | class CIFAR100(CIFAR10): 198 | """`CIFAR100 `_ Dataset. 199 | 200 | This is a subclass of the `CIFAR10` Dataset. 
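Like the `CIFAR10` class above, its `__getitem__` returns a stack of four rotated copies of each training image (the self-supervision signal used by SSKD), and it uses the 100-way 'fine_labels' as targets.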
201 | """ 202 | base_folder = 'cifar-100-python' 203 | url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 204 | filename = "cifar-100-python.tar.gz" 205 | tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85' 206 | train_list = [ 207 | ['train', '16019d7e3df5f24257cddd939b257f8d'], 208 | ] 209 | 210 | test_list = [ 211 | ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'], 212 | ] 213 | meta = { 214 | 'filename': 'meta', 215 | 'key': 'fine_label_names', 216 | 'md5': '7973b15100ade9c7d40fb424638fde48', 217 | } 218 | -------------------------------------------------------------------------------- /command.sh: -------------------------------------------------------------------------------- 1 | # teacher training 2 | python teacher.py --arch wrn_40_2 --lr 0.05 --gpu-id 0 3 | python teacher.py --arch wrn_40_1 --lr 0.05 --gpu-id 0 4 | python teacher.py --arch wrn_16_2 --lr 0.05 --gpu-id 0 5 | python teacher.py --arch vgg13 --lr 0.05 --gpu-id 0 6 | python teacher.py --arch vgg8 --lr 0.05 --gpu-id 0 7 | python teacher.py --arch resnet56 --lr 0.05 --gpu-id 0 8 | python teacher.py --arch resnet20 --lr 0.05 --gpu-id 0 9 | python teacher.py --arch resnet32x4 --lr 0.05 --gpu-id 0 10 | python teacher.py --arch resnet8x4 --lr 0.05 --gpu-id 0 11 | python teacher.py --arch ResNet50 --lr 0.05 --gpu-id 0 12 | python teacher.py --arch ShuffleV1 --lr 0.01 --gpu-id 0 13 | python teacher.py --arch ShuffleV2 --lr 0.01 --gpu-id 0 14 | python teacher.py --arch MobileNetV2 --lr 0.01 --gpu-id 0 15 | 16 | 17 | # student training 18 | 19 | # similar-architecture 20 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_16_2 --lr 0.05 --gpu-id 0 21 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch wrn_40_1 --lr 0.05 --gpu-id 0 22 | python student.py --t-path ./experiments/teacher_resnet56_seed0/ --s-arch resnet20 --lr 0.05 --gpu-id 0 23 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch resnet8x4 --lr 0.05 --gpu-id 0 24 | python student.py --t-path ./experiments/teacher_vgg13_seed0/ --s-arch vgg8 --lr 0.05 --gpu-id 0 25 | # different-architecture 26 | python student.py --t-path ./experiments/teacher_vgg13_seed0/ --s-arch MobileNetV2 --lr 0.01 --gpu-id 0 27 | python student.py --t-path ./experiments/teacher_ResNet50_seed0/ --s-arch MobileNetV2 --lr 0.01 --gpu-id 0 28 | python student.py --t-path ./experiments/teacher_ResNet50_seed0/ --s-arch vgg8 --lr 0.05 --gpu-id 0 29 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch ShuffleV1 --lr 0.01 --gpu-id 0 30 | python student.py --t-path ./experiments/teacher_resnet32x4_seed0/ --s-arch ShuffleV2 --lr 0.01 --gpu-id 0 31 | python student.py --t-path ./experiments/teacher_wrn_40_2_seed0/ --s-arch ShuffleV1 --lr 0.01 --gpu-id 0 32 | 33 | -------------------------------------------------------------------------------- /frm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xuguodong03/SSKD/661b972c124a83c32dcd0e203390ba637075fe92/frm.png -------------------------------------------------------------------------------- /models/ShuffleNetv1.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNet in PyTorch. 2 | See the paper "ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices" for more details. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N,C,H,W = x.size() 17 | g = self.groups 18 | return x.view(N,g,C//g,H,W).permute(0,2,1,3,4).reshape(N,C,H,W) 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | def __init__(self, in_planes, out_planes, stride, groups, is_last=False): 23 | super(Bottleneck, self).__init__() 24 | self.is_last = is_last 25 | self.stride = stride 26 | 27 | mid_planes = int(out_planes/4) 28 | g = 1 if in_planes == 24 else groups 29 | self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=g, bias=False) 30 | self.bn1 = nn.BatchNorm2d(mid_planes) 31 | self.shuffle1 = ShuffleBlock(groups=g) 32 | self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3, stride=stride, padding=1, groups=mid_planes, bias=False) 33 | self.bn2 = nn.BatchNorm2d(mid_planes) 34 | self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False) 35 | self.bn3 = nn.BatchNorm2d(out_planes) 36 | 37 | self.shortcut = nn.Sequential() 38 | if stride == 2: 39 | self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1)) 40 | 41 | def forward(self, x): 42 | out = F.relu(self.bn1(self.conv1(x))) 43 | out = self.shuffle1(out) 44 | out = F.relu(self.bn2(self.conv2(out))) 45 | out = self.bn3(self.conv3(out)) 46 | res = self.shortcut(x) 47 | preact = torch.cat([out, res], 1) if self.stride == 2 else out+res 48 | out = F.relu(preact) 49 | # out = F.relu(torch.cat([out, res], 1)) if self.stride == 2 else F.relu(out+res) 50 | if self.is_last: 51 | return out, preact 52 | else: 53 | return out 54 | 55 | 56 | class ShuffleNet(nn.Module): 57 | def __init__(self, cfg, num_classes=10): 58 | super(ShuffleNet, self).__init__() 59 | out_planes = cfg['out_planes'] 60 | num_blocks = cfg['num_blocks'] 61 | groups = cfg['groups'] 62 | 63 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(24) 65 | self.in_planes = 24 66 | self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups) 67 | self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups) 68 | self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups) 69 | self.linear = nn.Linear(out_planes[2], num_classes) 70 | 71 | def _make_layer(self, out_planes, num_blocks, groups): 72 | layers = [] 73 | for i in range(num_blocks): 74 | stride = 2 if i == 0 else 1 75 | cat_planes = self.in_planes if i == 0 else 0 76 | layers.append(Bottleneck(self.in_planes, out_planes-cat_planes, 77 | stride=stride, 78 | groups=groups, 79 | is_last=(i == num_blocks - 1))) 80 | self.in_planes = out_planes 81 | return nn.Sequential(*layers) 82 | 83 | def get_feat_modules(self): 84 | feat_m = nn.ModuleList([]) 85 | feat_m.append(self.conv1) 86 | feat_m.append(self.bn1) 87 | feat_m.append(self.layer1) 88 | feat_m.append(self.layer2) 89 | feat_m.append(self.layer3) 90 | return feat_m 91 | 92 | def get_bn_before_relu(self): 93 | raise NotImplementedError('ShuffleNet currently is not supported for "Overhaul" teacher') 94 | 95 | def forward(self, x, is_feat=False, preact=False): 96 | out = F.relu(self.bn1(self.conv1(x))) 97 | f0 = out 98 | out, f1_pre = self.layer1(out) 99 | f1 = out 100 | out, f2_pre = self.layer2(out) 101 | f2 = out 102 | out, f3_pre = self.layer3(out) 103 | f3 = out 104 
| out = F.avg_pool2d(out, 4) 105 | out = out.view(out.size(0), -1) 106 | f4 = out 107 | out = self.linear(out) 108 | 109 | if is_feat: 110 | if preact: 111 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 112 | else: 113 | return [f0, f1, f2, f3, f4], out 114 | else: 115 | return out 116 | 117 | 118 | def ShuffleV1(**kwargs): 119 | cfg = { 120 | 'out_planes': [240, 480, 960], 121 | 'num_blocks': [4, 8, 4], 122 | 'groups': 3 123 | } 124 | return ShuffleNet(cfg, **kwargs) 125 | 126 | 127 | if __name__ == '__main__': 128 | 129 | x = torch.randn(2, 3, 32, 32) 130 | net = ShuffleV1(num_classes=100) 131 | import time 132 | a = time.time() 133 | feats, logit = net(x, is_feat=True, preact=True) 134 | b = time.time() 135 | print(b - a) 136 | for f in feats: 137 | print(f.shape, f.min().item()) 138 | print(logit.shape) 139 | -------------------------------------------------------------------------------- /models/ShuffleNetv2.py: -------------------------------------------------------------------------------- 1 | '''ShuffleNetV2 in PyTorch. 2 | See the paper "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" for more details. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class ShuffleBlock(nn.Module): 10 | def __init__(self, groups=2): 11 | super(ShuffleBlock, self).__init__() 12 | self.groups = groups 13 | 14 | def forward(self, x): 15 | '''Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,w] -> [N,C,H,W]''' 16 | N, C, H, W = x.size() 17 | g = self.groups 18 | return x.view(N, g, C//g, H, W).permute(0, 2, 1, 3, 4).reshape(N, C, H, W) 19 | 20 | 21 | class SplitBlock(nn.Module): 22 | def __init__(self, ratio): 23 | super(SplitBlock, self).__init__() 24 | self.ratio = ratio 25 | 26 | def forward(self, x): 27 | c = int(x.size(1) * self.ratio) 28 | return x[:, :c, :, :], x[:, c:, :, :] 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | def __init__(self, in_channels, split_ratio=0.5, is_last=False): 33 | super(BasicBlock, self).__init__() 34 | self.is_last = is_last 35 | self.split = SplitBlock(split_ratio) 36 | in_channels = int(in_channels * split_ratio) 37 | self.conv1 = nn.Conv2d(in_channels, in_channels, 38 | kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(in_channels) 40 | self.conv2 = nn.Conv2d(in_channels, in_channels, 41 | kernel_size=3, stride=1, padding=1, groups=in_channels, bias=False) 42 | self.bn2 = nn.BatchNorm2d(in_channels) 43 | self.conv3 = nn.Conv2d(in_channels, in_channels, 44 | kernel_size=1, bias=False) 45 | self.bn3 = nn.BatchNorm2d(in_channels) 46 | self.shuffle = ShuffleBlock() 47 | 48 | def forward(self, x): 49 | x1, x2 = self.split(x) 50 | out = F.relu(self.bn1(self.conv1(x2))) 51 | out = self.bn2(self.conv2(out)) 52 | preact = self.bn3(self.conv3(out)) 53 | out = F.relu(preact) 54 | # out = F.relu(self.bn3(self.conv3(out))) 55 | preact = torch.cat([x1, preact], 1) 56 | out = torch.cat([x1, out], 1) 57 | out = self.shuffle(out) 58 | if self.is_last: 59 | return out, preact 60 | else: 61 | return out 62 | 63 | 64 | class DownBlock(nn.Module): 65 | def __init__(self, in_channels, out_channels): 66 | super(DownBlock, self).__init__() 67 | mid_channels = out_channels // 2 68 | # left 69 | self.conv1 = nn.Conv2d(in_channels, in_channels, 70 | kernel_size=3, stride=2, padding=1, groups=in_channels, bias=False) 71 | self.bn1 = nn.BatchNorm2d(in_channels) 72 | self.conv2 = nn.Conv2d(in_channels, mid_channels, 73 | kernel_size=1, bias=False) 74 | self.bn2 = nn.BatchNorm2d(mid_channels) 75 | # 
right 76 | self.conv3 = nn.Conv2d(in_channels, mid_channels, 77 | kernel_size=1, bias=False) 78 | self.bn3 = nn.BatchNorm2d(mid_channels) 79 | self.conv4 = nn.Conv2d(mid_channels, mid_channels, 80 | kernel_size=3, stride=2, padding=1, groups=mid_channels, bias=False) 81 | self.bn4 = nn.BatchNorm2d(mid_channels) 82 | self.conv5 = nn.Conv2d(mid_channels, mid_channels, 83 | kernel_size=1, bias=False) 84 | self.bn5 = nn.BatchNorm2d(mid_channels) 85 | 86 | self.shuffle = ShuffleBlock() 87 | 88 | def forward(self, x): 89 | # left 90 | out1 = self.bn1(self.conv1(x)) 91 | out1 = F.relu(self.bn2(self.conv2(out1))) 92 | # right 93 | out2 = F.relu(self.bn3(self.conv3(x))) 94 | out2 = self.bn4(self.conv4(out2)) 95 | out2 = F.relu(self.bn5(self.conv5(out2))) 96 | # concat 97 | out = torch.cat([out1, out2], 1) 98 | out = self.shuffle(out) 99 | return out 100 | 101 | 102 | class ShuffleNetV2(nn.Module): 103 | def __init__(self, net_size, num_classes=10): 104 | super(ShuffleNetV2, self).__init__() 105 | out_channels = configs[net_size]['out_channels'] 106 | num_blocks = configs[net_size]['num_blocks'] 107 | 108 | # self.conv1 = nn.Conv2d(3, 24, kernel_size=3, 109 | # stride=1, padding=1, bias=False) 110 | self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False) 111 | self.bn1 = nn.BatchNorm2d(24) 112 | self.in_channels = 24 113 | self.layer1 = self._make_layer(out_channels[0], num_blocks[0]) 114 | self.layer2 = self._make_layer(out_channels[1], num_blocks[1]) 115 | self.layer3 = self._make_layer(out_channels[2], num_blocks[2]) 116 | self.conv2 = nn.Conv2d(out_channels[2], out_channels[3], 117 | kernel_size=1, stride=1, padding=0, bias=False) 118 | self.bn2 = nn.BatchNorm2d(out_channels[3]) 119 | self.linear = nn.Linear(out_channels[3], num_classes) 120 | 121 | def _make_layer(self, out_channels, num_blocks): 122 | layers = [DownBlock(self.in_channels, out_channels)] 123 | for i in range(num_blocks): 124 | layers.append(BasicBlock(out_channels, is_last=(i == num_blocks - 1))) 125 | self.in_channels = out_channels 126 | return nn.Sequential(*layers) 127 | 128 | def get_feat_modules(self): 129 | feat_m = nn.ModuleList([]) 130 | feat_m.append(self.conv1) 131 | feat_m.append(self.bn1) 132 | feat_m.append(self.layer1) 133 | feat_m.append(self.layer2) 134 | feat_m.append(self.layer3) 135 | return feat_m 136 | 137 | def get_bn_before_relu(self): 138 | raise NotImplementedError('ShuffleNetV2 currently is not supported for "Overhaul" teacher') 139 | 140 | def forward(self, x, is_feat=False, preact=False): 141 | out = F.relu(self.bn1(self.conv1(x))) 142 | # out = F.max_pool2d(out, 3, stride=2, padding=1) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out = F.relu(self.bn2(self.conv2(out))) 151 | out = F.avg_pool2d(out, 4) 152 | out = out.view(out.size(0), -1) 153 | f4 = out 154 | out = self.linear(out) 155 | if is_feat: 156 | if preact: 157 | return [f0, f1_pre, f2_pre, f3_pre, f4], out 158 | else: 159 | return [f0, f1, f2, f3, f4], out 160 | else: 161 | return out 162 | 163 | 164 | configs = { 165 | 0.2: { 166 | 'out_channels': (40, 80, 160, 512), 167 | 'num_blocks': (3, 3, 3) 168 | }, 169 | 170 | 0.3: { 171 | 'out_channels': (40, 80, 160, 512), 172 | 'num_blocks': (3, 7, 3) 173 | }, 174 | 175 | 0.5: { 176 | 'out_channels': (48, 96, 192, 1024), 177 | 'num_blocks': (3, 7, 3) 178 | }, 179 | 180 | 1: { 181 | 'out_channels': (116, 232, 464, 1024), 182 | 'num_blocks': (3, 7, 3) 183 | }, 184 | 
1.5: { 185 | 'out_channels': (176, 352, 704, 1024), 186 | 'num_blocks': (3, 7, 3) 187 | }, 188 | 2: { 189 | 'out_channels': (224, 488, 976, 2048), 190 | 'num_blocks': (3, 7, 3) 191 | } 192 | } 193 | 194 | 195 | def ShuffleV2(**kwargs): 196 | model = ShuffleNetV2(net_size=1, **kwargs) 197 | return model 198 | 199 | 200 | if __name__ == '__main__': 201 | net = ShuffleV2(num_classes=100) 202 | x = torch.randn(3, 3, 32, 32) 203 | import time 204 | a = time.time() 205 | feats, logit = net(x, is_feat=True, preact=True) 206 | b = time.time() 207 | print(b - a) 208 | for f in feats: 209 | print(f.shape, f.min().item()) 210 | print(logit.shape) 211 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet8, resnet14, resnet20, resnet32, resnet44, resnet56, resnet110, resnet8x4, resnet32x4, resnet14x05, resnet20x05, resnet20x0375 2 | from .resnetv2 import ResNet50 3 | from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2 4 | from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn 5 | from .mobilenetv2 import mobile_half 6 | from .ShuffleNetv1 import ShuffleV1 7 | from .ShuffleNetv2 import ShuffleV2 8 | 9 | model_dict = { 10 | 'resnet8': resnet8, 11 | 'resnet14': resnet14, 12 | 'resnet20': resnet20, 13 | 'resnet32': resnet32, 14 | 'resnet44': resnet44, 15 | 'resnet56': resnet56, 16 | 'resnet110': resnet110, 17 | 'resnet8x4': resnet8x4, 18 | 'resnet32x4': resnet32x4, 19 | 'ResNet50': ResNet50, 20 | 'wrn_16_1': wrn_16_1, 21 | 'wrn_16_2': wrn_16_2, 22 | 'wrn_40_1': wrn_40_1, 23 | 'wrn_40_2': wrn_40_2, 24 | 'vgg8': vgg8_bn, 25 | 'vgg11': vgg11_bn, 26 | 'vgg13': vgg13_bn, 27 | 'vgg16': vgg16_bn, 28 | 'vgg19': vgg19_bn, 29 | 'MobileNetV2': mobile_half, 30 | 'ShuffleV1': ShuffleV1, 31 | 'ShuffleV2': ShuffleV2, 32 | 'resnet14x05': resnet14x05, 33 | 'resnet20x05': resnet20x05, 34 | 'resnet20x0375': resnet20x0375, 35 | } 36 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | 5 | 6 | ######################################### 7 | # ===== Classifiers ===== # 8 | ######################################### 9 | 10 | class LinearClassifier(nn.Module): 11 | 12 | def __init__(self, dim_in, n_label=10): 13 | super(LinearClassifier, self).__init__() 14 | 15 | self.net = nn.Linear(dim_in, n_label) 16 | 17 | def forward(self, x): 18 | return self.net(x) 19 | 20 | 21 | class NonLinearClassifier(nn.Module): 22 | 23 | def __init__(self, dim_in, n_label=10, p=0.1): 24 | super(NonLinearClassifier, self).__init__() 25 | 26 | self.net = nn.Sequential( 27 | nn.Linear(dim_in, 200), 28 | nn.Dropout(p=p), 29 | nn.BatchNorm1d(200), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(200, n_label), 32 | ) 33 | 34 | def forward(self, x): 35 | return self.net(x) 36 | -------------------------------------------------------------------------------- /models/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | MobileNetV2 implementation used in 3 | 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | 10 | __all__ = ['mobilenetv2_T_w', 'mobile_half'] 11 | 12 | BN = None 13 | 14 | 15 | def conv_bn(inp, oup, stride): 16 | return nn.Sequential( 17 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 18 | 
nn.BatchNorm2d(oup), 19 | nn.ReLU(inplace=True) 20 | ) 21 | 22 | 23 | def conv_1x1_bn(inp, oup): 24 | return nn.Sequential( 25 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 26 | nn.BatchNorm2d(oup), 27 | nn.ReLU(inplace=True) 28 | ) 29 | 30 | 31 | class InvertedResidual(nn.Module): 32 | def __init__(self, inp, oup, stride, expand_ratio): 33 | super(InvertedResidual, self).__init__() 34 | self.blockname = None 35 | 36 | self.stride = stride 37 | assert stride in [1, 2] 38 | 39 | self.use_res_connect = self.stride == 1 and inp == oup 40 | 41 | self.conv = nn.Sequential( 42 | # pw 43 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 44 | nn.BatchNorm2d(inp * expand_ratio), 45 | nn.ReLU(inplace=True), 46 | # dw 47 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 48 | nn.BatchNorm2d(inp * expand_ratio), 49 | nn.ReLU(inplace=True), 50 | # pw-linear 51 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 52 | nn.BatchNorm2d(oup), 53 | ) 54 | self.names = ['0', '1', '2', '3', '4', '5', '6', '7'] 55 | 56 | def forward(self, x): 57 | t = x 58 | if self.use_res_connect: 59 | return t + self.conv(x) 60 | else: 61 | return self.conv(x) 62 | 63 | 64 | class MobileNetV2(nn.Module): 65 | """mobilenetV2""" 66 | def __init__(self, T, 67 | feature_dim, 68 | input_size=32, 69 | width_mult=1., 70 | remove_avg=False): 71 | super(MobileNetV2, self).__init__() 72 | self.remove_avg = remove_avg 73 | 74 | # setting of inverted residual blocks 75 | self.interverted_residual_setting = [ 76 | # t, c, n, s 77 | [1, 16, 1, 1], 78 | [T, 24, 2, 1], 79 | [T, 32, 3, 2], 80 | [T, 64, 4, 2], 81 | [T, 96, 3, 1], 82 | [T, 160, 3, 2], 83 | [T, 320, 1, 1], 84 | ] 85 | 86 | # building first layer 87 | assert input_size % 32 == 0 88 | input_channel = int(32 * width_mult) 89 | self.conv1 = conv_bn(3, input_channel, 2) 90 | 91 | # building inverted residual blocks 92 | self.blocks = nn.ModuleList([]) 93 | for t, c, n, s in self.interverted_residual_setting: 94 | output_channel = int(c * width_mult) 95 | layers = [] 96 | strides = [s] + [1] * (n - 1) 97 | for stride in strides: 98 | layers.append( 99 | InvertedResidual(input_channel, output_channel, stride, t) 100 | ) 101 | input_channel = output_channel 102 | self.blocks.append(nn.Sequential(*layers)) 103 | 104 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 105 | self.conv2 = conv_1x1_bn(input_channel, self.last_channel) 106 | 107 | H = input_size // (32//2) 108 | self.avgpool = nn.AvgPool2d(H, ceil_mode=True) 109 | 110 | # building classifier 111 | #self.classifier = nn.Sequential( 112 | # # nn.Dropout(0.5), 113 | # nn.Linear(self.last_channel, feature_dim), 114 | #) 115 | self.classifier = nn.Linear(self.last_channel, feature_dim) 116 | 117 | self._initialize_weights() 118 | print(T, width_mult) 119 | 120 | def get_bn_before_relu(self): 121 | bn1 = self.blocks[1][-1].conv[-1] 122 | bn2 = self.blocks[2][-1].conv[-1] 123 | bn3 = self.blocks[4][-1].conv[-1] 124 | bn4 = self.blocks[6][-1].conv[-1] 125 | return [bn1, bn2, bn3, bn4] 126 | 127 | def get_feat_modules(self): 128 | feat_m = nn.ModuleList([]) 129 | feat_m.append(self.conv1) 130 | feat_m.append(self.blocks) 131 | return feat_m 132 | 133 | def forward(self, x, is_feat=False, preact=False): 134 | 135 | out = self.conv1(x) 136 | f0 = out 137 | 138 | out = self.blocks[0](out) 139 | out = self.blocks[1](out) 140 | f1 = out 141 | out = self.blocks[2](out) 142 | f2 = out 143 | out = self.blocks[3](out) 144 | out = self.blocks[4](out) 145 
| f3 = out 146 | out = self.blocks[5](out) 147 | out = self.blocks[6](out) 148 | f4 = out 149 | 150 | out = self.conv2(out) 151 | 152 | if not self.remove_avg: 153 | out = self.avgpool(out) 154 | out = out.view(out.size(0), -1) 155 | f5 = out 156 | out = self.classifier(out) 157 | 158 | if is_feat: 159 | return [f0, f1, f2, f3, f4, f5], out 160 | else: 161 | return out 162 | 163 | def _initialize_weights(self): 164 | for m in self.modules(): 165 | if isinstance(m, nn.Conv2d): 166 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 167 | m.weight.data.normal_(0, math.sqrt(2. / n)) 168 | if m.bias is not None: 169 | m.bias.data.zero_() 170 | elif isinstance(m, nn.BatchNorm2d): 171 | m.weight.data.fill_(1) 172 | m.bias.data.zero_() 173 | elif isinstance(m, nn.Linear): 174 | n = m.weight.size(1) 175 | m.weight.data.normal_(0, 0.01) 176 | m.bias.data.zero_() 177 | 178 | 179 | def mobilenetv2_T_w(T, W, feature_dim=100): 180 | model = MobileNetV2(T=T, feature_dim=feature_dim, width_mult=W) 181 | return model 182 | 183 | 184 | def mobile_half(num_classes): 185 | return mobilenetv2_T_w(6, 0.5, num_classes) 186 | 187 | 188 | if __name__ == '__main__': 189 | x = torch.randn(2, 3, 32, 32) 190 | 191 | net = mobile_half(100) 192 | 193 | feats, logit = net(x, is_feat=True, preact=True) 194 | for f in feats: 195 | print(f.shape, f.min().item()) 196 | print(logit.shape) 197 | 198 | for m in net.get_bn_before_relu(): 199 | if isinstance(m, nn.BatchNorm2d): 200 | print('pass') 201 | else: 202 | print('warning') 203 | 204 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | '''Resnet for cifar dataset. 
4 | Ported form 5 | https://github.com/facebook/fb.resnet.torch 6 | and 7 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 8 | (c) YANG, Wei 9 | ''' 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | 15 | __all__ = ['resnet'] 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | """3x3 convolution with padding""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 28 | super(BasicBlock, self).__init__() 29 | self.is_last = is_last 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | preact = out 53 | out = F.relu(out) 54 | if self.is_last: 55 | return out, preact 56 | else: 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None, is_last=False): 64 | super(Bottleneck, self).__init__() 65 | self.is_last = is_last 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * 4) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | preact = out 96 | out = F.relu(out) 97 | if self.is_last: 98 | return out, preact 99 | else: 100 | return out 101 | 102 | 103 | class ResNet(nn.Module): 104 | 105 | def __init__(self, depth, num_filters, block_name='BasicBlock', num_classes=10): 106 | super(ResNet, self).__init__() 107 | # Model type specifies number of layers for CIFAR-10 model 108 | if block_name.lower() == 'basicblock': 109 | assert (depth - 2) % 6 == 0, 'When use basicblock, depth should be 6n+2, e.g. 20, 32, 44, 56, 110, 1202' 110 | n = (depth - 2) // 6 111 | block = BasicBlock 112 | elif block_name.lower() == 'bottleneck': 113 | assert (depth - 2) % 9 == 0, 'When use bottleneck, depth should be 9n+2, e.g. 
20, 29, 47, 56, 110, 1199' 114 | n = (depth - 2) // 9 115 | block = Bottleneck 116 | else: 117 | raise ValueError('block_name shoule be Basicblock or Bottleneck') 118 | 119 | self.inplanes = num_filters[0] 120 | self.conv1 = nn.Conv2d(3, num_filters[0], kernel_size=3, padding=1, 121 | bias=False) 122 | self.bn1 = nn.BatchNorm2d(num_filters[0]) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.layer1 = self._make_layer(block, num_filters[1], n) 125 | self.layer2 = self._make_layer(block, num_filters[2], n, stride=2) 126 | self.layer3 = self._make_layer(block, num_filters[3], n, stride=2) 127 | self.avgpool = nn.AvgPool2d(8) 128 | self.fc = nn.Linear(num_filters[3] * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = list([]) 147 | layers.append(block(self.inplanes, planes, stride, downsample, is_last=(blocks == 1))) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, is_last=(i == blocks-1))) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def get_feat_modules(self): 155 | feat_m = nn.ModuleList([]) 156 | feat_m.append(self.conv1) 157 | feat_m.append(self.bn1) 158 | feat_m.append(self.relu) 159 | feat_m.append(self.layer1) 160 | feat_m.append(self.layer2) 161 | feat_m.append(self.layer3) 162 | return feat_m 163 | 164 | def get_bn_before_relu(self): 165 | if isinstance(self.layer1[0], Bottleneck): 166 | bn1 = self.layer1[-1].bn3 167 | bn2 = self.layer2[-1].bn3 168 | bn3 = self.layer3[-1].bn3 169 | elif isinstance(self.layer1[0], BasicBlock): 170 | bn1 = self.layer1[-1].bn2 171 | bn2 = self.layer2[-1].bn2 172 | bn3 = self.layer3[-1].bn2 173 | else: 174 | raise NotImplementedError('ResNet unknown block error !!!') 175 | 176 | return [bn1, bn2, bn3] 177 | 178 | def forward(self, x, is_feat=False, preact=False): 179 | x = self.conv1(x) 180 | x = self.bn1(x) 181 | x = self.relu(x) # 32x32 182 | f0 = x 183 | 184 | x, f1_pre = self.layer1(x) # 32x32 185 | f1 = x 186 | x, f2_pre = self.layer2(x) # 16x16 187 | f2 = x 188 | x, f3_pre = self.layer3(x) # 8x8 189 | f3 = x 190 | 191 | x = self.avgpool(x) 192 | x = x.view(x.size(0), -1) 193 | f4 = x 194 | x = self.fc(x) 195 | 196 | if is_feat: 197 | if preact: 198 | return [f0, f1_pre, f2_pre, f3_pre, f4], x 199 | else: 200 | return [f0, f1, f2, f3, f4], x 201 | else: 202 | return x 203 | 204 | 205 | def resnet8(**kwargs): 206 | return ResNet(8, [16, 16, 32, 64], 'basicblock', **kwargs) 207 | 208 | 209 | def resnet14(**kwargs): 210 | return ResNet(14, [16, 16, 32, 64], 'basicblock', **kwargs) 211 | 212 | def resnet20(**kwargs): 213 | return ResNet(20, [16, 16, 32, 64], 'basicblock', **kwargs) 214 | 215 | 216 | def resnet14x05(**kwargs): 217 | return ResNet(14, [8, 8, 16, 32], 'basicblock', **kwargs) 218 | 219 | def resnet20x05(**kwargs): 220 | return ResNet(20, [8, 8, 16, 32], 'basicblock', **kwargs) 221 | 222 | def resnet20x0375(**kwargs): 223 | 
return ResNet(20, [6, 6, 12, 24], 'basicblock', **kwargs) 224 | 225 | 226 | 227 | def resnet32(**kwargs): 228 | return ResNet(32, [16, 16, 32, 64], 'basicblock', **kwargs) 229 | 230 | 231 | def resnet44(**kwargs): 232 | return ResNet(44, [16, 16, 32, 64], 'basicblock', **kwargs) 233 | 234 | 235 | def resnet56(**kwargs): 236 | return ResNet(56, [16, 16, 32, 64], 'basicblock', **kwargs) 237 | 238 | 239 | def resnet110(**kwargs): 240 | return ResNet(110, [16, 16, 32, 64], 'basicblock', **kwargs) 241 | 242 | 243 | def resnet8x4(**kwargs): 244 | return ResNet(8, [32, 64, 128, 256], 'basicblock', **kwargs) 245 | 246 | 247 | def resnet32x4(**kwargs): 248 | return ResNet(32, [32, 64, 128, 256], 'basicblock', **kwargs) 249 | 250 | 251 | if __name__ == '__main__': 252 | import torch 253 | 254 | x = torch.randn(2, 3, 32, 32) 255 | net = resnet8x4(num_classes=20) 256 | feats, logit = net(x, is_feat=True, preact=True) 257 | 258 | for f in feats: 259 | print(f.shape, f.min().item()) 260 | print(logit.shape) 261 | 262 | for m in net.get_bn_before_relu(): 263 | if isinstance(m, nn.BatchNorm2d): 264 | print('pass') 265 | else: 266 | print('warning') 267 | -------------------------------------------------------------------------------- /models/resnetv2.py: -------------------------------------------------------------------------------- 1 | '''ResNet in PyTorch. 2 | For Pre-activation ResNet, see 'preact_resnet.py'. 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, in_planes, planes, stride=1, is_last=False): 16 | super(BasicBlock, self).__init__() 17 | self.is_last = is_last 18 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.shortcut = nn.Sequential() 24 | if stride != 1 or in_planes != self.expansion * planes: 25 | self.shortcut = nn.Sequential( 26 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 27 | nn.BatchNorm2d(self.expansion * planes) 28 | ) 29 | 30 | def forward(self, x): 31 | out = F.relu(self.bn1(self.conv1(x))) 32 | out = self.bn2(self.conv2(out)) 33 | out += self.shortcut(x) 34 | preact = out 35 | out = F.relu(out) 36 | if self.is_last: 37 | return out, preact 38 | else: 39 | return out 40 | 41 | 42 | class Bottleneck(nn.Module): 43 | expansion = 4 44 | 45 | def __init__(self, in_planes, planes, stride=1, is_last=False): 46 | super(Bottleneck, self).__init__() 47 | self.is_last = is_last 48 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 51 | self.bn2 = nn.BatchNorm2d(planes) 52 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 53 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 54 | 55 | self.shortcut = nn.Sequential() 56 | if stride != 1 or in_planes != self.expansion * planes: 57 | self.shortcut = nn.Sequential( 58 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 59 | nn.BatchNorm2d(self.expansion * planes) 60 | ) 61 | 62 
| def forward(self, x): 63 | out = F.relu(self.bn1(self.conv1(x))) 64 | out = F.relu(self.bn2(self.conv2(out))) 65 | out = self.bn3(self.conv3(out)) 66 | out += self.shortcut(x) 67 | preact = out 68 | out = F.relu(out) 69 | if self.is_last: 70 | return out, preact 71 | else: 72 | return out 73 | 74 | 75 | class ResNet(nn.Module): 76 | def __init__(self, block, num_blocks, num_classes=10, zero_init_residual=False): 77 | super(ResNet, self).__init__() 78 | self.in_planes = 64 79 | 80 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 81 | self.bn1 = nn.BatchNorm2d(64) 82 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 83 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 84 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 85 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 86 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 87 | self.linear = nn.Linear(512 * block.expansion, num_classes) 88 | 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 92 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 93 | nn.init.constant_(m.weight, 1) 94 | nn.init.constant_(m.bias, 0) 95 | 96 | # Zero-initialize the last BN in each residual branch, 97 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 98 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 99 | if zero_init_residual: 100 | for m in self.modules(): 101 | if isinstance(m, Bottleneck): 102 | nn.init.constant_(m.bn3.weight, 0) 103 | elif isinstance(m, BasicBlock): 104 | nn.init.constant_(m.bn2.weight, 0) 105 | 106 | def get_feat_modules(self): 107 | feat_m = nn.ModuleList([]) 108 | feat_m.append(self.conv1) 109 | feat_m.append(self.bn1) 110 | feat_m.append(self.layer1) 111 | feat_m.append(self.layer2) 112 | feat_m.append(self.layer3) 113 | feat_m.append(self.layer4) 114 | return feat_m 115 | 116 | def get_bn_before_relu(self): 117 | if isinstance(self.layer1[0], Bottleneck): 118 | bn1 = self.layer1[-1].bn3 119 | bn2 = self.layer2[-1].bn3 120 | bn3 = self.layer3[-1].bn3 121 | bn4 = self.layer4[-1].bn3 122 | elif isinstance(self.layer1[0], BasicBlock): 123 | bn1 = self.layer1[-1].bn2 124 | bn2 = self.layer2[-1].bn2 125 | bn3 = self.layer3[-1].bn2 126 | bn4 = self.layer4[-1].bn2 127 | else: 128 | raise NotImplementedError('ResNet unknown block error !!!') 129 | 130 | return [bn1, bn2, bn3, bn4] 131 | 132 | def _make_layer(self, block, planes, num_blocks, stride): 133 | strides = [stride] + [1] * (num_blocks - 1) 134 | layers = [] 135 | for i in range(num_blocks): 136 | stride = strides[i] 137 | layers.append(block(self.in_planes, planes, stride, i == num_blocks - 1)) 138 | self.in_planes = planes * block.expansion 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x, is_feat=False, preact=False): 142 | out = F.relu(self.bn1(self.conv1(x))) 143 | f0 = out 144 | out, f1_pre = self.layer1(out) 145 | f1 = out 146 | out, f2_pre = self.layer2(out) 147 | f2 = out 148 | out, f3_pre = self.layer3(out) 149 | f3 = out 150 | out, f4_pre = self.layer4(out) 151 | f4 = out 152 | out = self.avgpool(out) 153 | out = out.view(out.size(0), -1) 154 | f5 = out 155 | out = self.linear(out) 156 | if is_feat: 157 | if preact: 158 | return [[f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], out] 159 | else: 160 | return [f0, f1, f2, f3, f4, f5], out 161 | else: 162 | return out 163 | 164 | 165 
| def ResNet18(**kwargs): 166 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 167 | 168 | 169 | def ResNet34(**kwargs): 170 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 171 | 172 | 173 | def ResNet50(**kwargs): 174 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 175 | 176 | 177 | def ResNet101(**kwargs): 178 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 179 | 180 | 181 | def ResNet152(**kwargs): 182 | return ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 183 | 184 | 185 | if __name__ == '__main__': 186 | net = ResNet18(num_classes=100) 187 | x = torch.randn(2, 3, 32, 32) 188 | feats, logit = net(x, is_feat=True, preact=True) 189 | 190 | for f in feats: 191 | print(f.shape, f.min().item()) 192 | print(logit.shape) 193 | 194 | for m in net.get_bn_before_relu(): 195 | if isinstance(m, nn.BatchNorm2d): 196 | print('pass') 197 | else: 198 | print('warning') 199 | -------------------------------------------------------------------------------- /models/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class Paraphraser(nn.Module): 8 | """Paraphrasing Complex Network: Network Compression via Factor Transfer""" 9 | def __init__(self, t_shape, k=0.5, use_bn=False): 10 | super(Paraphraser, self).__init__() 11 | in_channel = t_shape[1] 12 | out_channel = int(t_shape[1] * k) 13 | self.encoder = nn.Sequential( 14 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 15 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 16 | nn.LeakyReLU(0.1, inplace=True), 17 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 18 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 19 | nn.LeakyReLU(0.1, inplace=True), 20 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 21 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 22 | nn.LeakyReLU(0.1, inplace=True), 23 | ) 24 | self.decoder = nn.Sequential( 25 | nn.ConvTranspose2d(out_channel, out_channel, 3, 1, 1), 26 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 27 | nn.LeakyReLU(0.1, inplace=True), 28 | nn.ConvTranspose2d(out_channel, in_channel, 3, 1, 1), 29 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 30 | nn.LeakyReLU(0.1, inplace=True), 31 | nn.ConvTranspose2d(in_channel, in_channel, 3, 1, 1), 32 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 33 | nn.LeakyReLU(0.1, inplace=True), 34 | ) 35 | 36 | def forward(self, f_s, is_factor=False): 37 | factor = self.encoder(f_s) 38 | if is_factor: 39 | return factor 40 | rec = self.decoder(factor) 41 | return factor, rec 42 | 43 | 44 | class Translator(nn.Module): 45 | def __init__(self, s_shape, t_shape, k=0.5, use_bn=True): 46 | super(Translator, self).__init__() 47 | in_channel = s_shape[1] 48 | out_channel = int(t_shape[1] * k) 49 | self.encoder = nn.Sequential( 50 | nn.Conv2d(in_channel, in_channel, 3, 1, 1), 51 | nn.BatchNorm2d(in_channel) if use_bn else nn.Sequential(), 52 | nn.LeakyReLU(0.1, inplace=True), 53 | nn.Conv2d(in_channel, out_channel, 3, 1, 1), 54 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 55 | nn.LeakyReLU(0.1, inplace=True), 56 | nn.Conv2d(out_channel, out_channel, 3, 1, 1), 57 | nn.BatchNorm2d(out_channel) if use_bn else nn.Sequential(), 58 | nn.LeakyReLU(0.1, inplace=True), 59 | ) 60 | 61 | def forward(self, f_s): 62 | return self.encoder(f_s) 63 | 64 | 65 | class Connector(nn.Module): 66 | """Connect for Knowledge Transfer via Distillation of Activation Boundaries Formed 
by Hidden Neurons""" 67 | def __init__(self, s_shapes, t_shapes): 68 | super(Connector, self).__init__() 69 | self.s_shapes = s_shapes 70 | self.t_shapes = t_shapes 71 | 72 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 73 | 74 | @staticmethod 75 | def _make_conenctors(s_shapes, t_shapes): 76 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 77 | connectors = [] 78 | for s, t in zip(s_shapes, t_shapes): 79 | if s[1] == t[1] and s[2] == t[2]: 80 | connectors.append(nn.Sequential()) 81 | else: 82 | connectors.append(ConvReg(s, t, use_relu=False)) 83 | return connectors 84 | 85 | def forward(self, g_s): 86 | out = [] 87 | for i in range(len(g_s)): 88 | out.append(self.connectors[i](g_s[i])) 89 | 90 | return out 91 | 92 | 93 | class ConnectorV2(nn.Module): 94 | """A Comprehensive Overhaul of Feature Distillation (ICCV 2019)""" 95 | def __init__(self, s_shapes, t_shapes): 96 | super(ConnectorV2, self).__init__() 97 | self.s_shapes = s_shapes 98 | self.t_shapes = t_shapes 99 | 100 | self.connectors = nn.ModuleList(self._make_conenctors(s_shapes, t_shapes)) 101 | 102 | def _make_conenctors(self, s_shapes, t_shapes): 103 | assert len(s_shapes) == len(t_shapes), 'unequal length of feat list' 104 | t_channels = [t[1] for t in t_shapes] 105 | s_channels = [s[1] for s in s_shapes] 106 | connectors = nn.ModuleList([self._build_feature_connector(t, s) 107 | for t, s in zip(t_channels, s_channels)]) 108 | return connectors 109 | 110 | @staticmethod 111 | def _build_feature_connector(t_channel, s_channel): 112 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 113 | nn.BatchNorm2d(t_channel)] 114 | for m in C: 115 | if isinstance(m, nn.Conv2d): 116 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 117 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 118 | elif isinstance(m, nn.BatchNorm2d): 119 | m.weight.data.fill_(1) 120 | m.bias.data.zero_() 121 | return nn.Sequential(*C) 122 | 123 | def forward(self, g_s): 124 | out = [] 125 | for i in range(len(g_s)): 126 | out.append(self.connectors[i](g_s[i])) 127 | 128 | return out 129 | 130 | 131 | class ConvReg(nn.Module): 132 | """Convolutional regression for FitNet""" 133 | def __init__(self, s_shape, t_shape, use_relu=True): 134 | super(ConvReg, self).__init__() 135 | self.use_relu = use_relu 136 | s_N, s_C, s_H, s_W = s_shape 137 | t_N, t_C, t_H, t_W = t_shape 138 | if s_H == 2 * t_H: 139 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1) 140 | elif s_H * 2 == t_H: 141 | self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1) 142 | elif s_H >= t_H: 143 | self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W)) 144 | else: 145 | raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H)) 146 | self.bn = nn.BatchNorm2d(t_C) 147 | self.relu = nn.ReLU(inplace=True) 148 | 149 | def forward(self, x): 150 | x = self.conv(x) 151 | if self.use_relu: 152 | return self.relu(self.bn(x)) 153 | else: 154 | return self.bn(x) 155 | 156 | 157 | class Regress(nn.Module): 158 | """Simple Linear Regression for hints""" 159 | def __init__(self, dim_in=1024, dim_out=1024): 160 | super(Regress, self).__init__() 161 | self.linear = nn.Linear(dim_in, dim_out) 162 | self.relu = nn.ReLU(inplace=True) 163 | 164 | def forward(self, x): 165 | x = x.view(x.shape[0], -1) 166 | x = self.linear(x) 167 | x = self.relu(x) 168 | return x 169 | 170 | 171 | class Embed(nn.Module): 172 | """Embedding module""" 173 | def __init__(self, dim_in=1024, dim_out=128): 174 | super(Embed, self).__init__() 175 | self.linear = nn.Linear(dim_in, dim_out) 176 | self.l2norm = Normalize(2) 177 | 178 | def forward(self, x): 179 | x = x.view(x.shape[0], -1) 180 | x = self.linear(x) 181 | x = self.l2norm(x) 182 | return x 183 | 184 | 185 | class LinearEmbed(nn.Module): 186 | """Linear Embedding""" 187 | def __init__(self, dim_in=1024, dim_out=128): 188 | super(LinearEmbed, self).__init__() 189 | self.linear = nn.Linear(dim_in, dim_out) 190 | 191 | def forward(self, x): 192 | x = x.view(x.shape[0], -1) 193 | x = self.linear(x) 194 | return x 195 | 196 | 197 | class MLPEmbed(nn.Module): 198 | """non-linear embed by MLP""" 199 | def __init__(self, dim_in=1024, dim_out=128): 200 | super(MLPEmbed, self).__init__() 201 | self.linear1 = nn.Linear(dim_in, 2 * dim_out) 202 | self.relu = nn.ReLU(inplace=True) 203 | self.linear2 = nn.Linear(2 * dim_out, dim_out) 204 | self.l2norm = Normalize(2) 205 | 206 | def forward(self, x): 207 | x = x.view(x.shape[0], -1) 208 | x = self.relu(self.linear1(x)) 209 | x = self.l2norm(self.linear2(x)) 210 | return x 211 | 212 | 213 | class Normalize(nn.Module): 214 | """normalization layer""" 215 | def __init__(self, power=2): 216 | super(Normalize, self).__init__() 217 | self.power = power 218 | 219 | def forward(self, x): 220 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 221 | out = x.div(norm) 222 | return out 223 | 224 | 225 | class Flatten(nn.Module): 226 | """flatten module""" 227 | def __init__(self): 228 | super(Flatten, self).__init__() 229 | 230 | def forward(self, feat): 231 | return feat.view(feat.size(0), -1) 232 | 233 | 234 | class PoolEmbed(nn.Module): 235 | """pool and embed""" 236 | def __init__(self, layer=0, dim_out=128, pool_type='avg'): 237 | super().__init__() 238 | if layer == 0: 239 | pool_size = 8 240 | nChannels = 16 241 | elif layer == 1: 242 | pool_size = 8 243 | nChannels = 16 244 | elif layer == 2: 245 | pool_size = 6 246 | nChannels = 32 247 | elif layer == 3: 248 | pool_size = 4 249 | nChannels = 64 250 | elif layer == 4: 251 | pool_size = 1 252 | nChannels = 64 253 | else: 254 | raise NotImplementedError('layer not supported: {}'.format(layer)) 255 | 256 | self.embed = nn.Sequential() 257 | if layer <= 3: 258 | if pool_type == 'max': 259 | self.embed.add_module('MaxPool', nn.AdaptiveMaxPool2d((pool_size, pool_size))) 260 | elif pool_type == 'avg': 261 | self.embed.add_module('AvgPool', nn.AdaptiveAvgPool2d((pool_size, pool_size))) 262 | 263 | self.embed.add_module('Flatten', Flatten()) 264 | self.embed.add_module('Linear', nn.Linear(nChannels*pool_size*pool_size, dim_out)) 265 | self.embed.add_module('Normalize', Normalize(2)) 266 | 267 | def forward(self, x): 268 | return self.embed(x) 269 | 270 | 271 | if __name__ == '__main__': 272 | import torch 273 | 274 | g_s = [ 275 | torch.randn(2, 16, 16, 16), 276 | torch.randn(2, 32, 8, 8), 277 | torch.randn(2, 64, 4, 4), 278 | ] 279 | g_t = [ 280 | torch.randn(2, 32, 16, 16), 281 | torch.randn(2, 64, 8, 8), 282 | torch.randn(2, 128, 4, 4), 283 | ] 284 | s_shapes = [s.shape for s in g_s] 285 | t_shapes = [t.shape for t in g_t] 286 | 287 | net = ConnectorV2(s_shapes, t_shapes) 288 | out = net(g_s) 289 | for f in out: 290 | print(f.shape) 291 | -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | '''VGG for CIFAR10. FC layers are removed. 
2 | (c) YANG, Wei 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import math 7 | 8 | 9 | __all__ = [ 10 | 'VGG', 'vgg8', 'vgg8_bn', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 11 | 'vgg19_bn', 'vgg19', 12 | ] 13 | 14 | 15 | model_urls = { 16 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 17 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 18 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 19 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 20 | } 21 | 22 | 23 | class VGG(nn.Module): 24 | 25 | def __init__(self, cfg, batch_norm=False, num_classes=1000): 26 | super(VGG, self).__init__() 27 | self.block0 = self._make_layers(cfg[0], batch_norm, 3) 28 | self.block1 = self._make_layers(cfg[1], batch_norm, cfg[0][-1]) 29 | self.block2 = self._make_layers(cfg[2], batch_norm, cfg[1][-1]) 30 | self.block3 = self._make_layers(cfg[3], batch_norm, cfg[2][-1]) 31 | self.block4 = self._make_layers(cfg[4], batch_norm, cfg[3][-1]) 32 | 33 | self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2) 34 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 35 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 36 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 37 | self.pool4 = nn.AdaptiveAvgPool2d((1, 1)) 38 | # self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 39 | 40 | self.classifier = nn.Linear(512, num_classes) 41 | self._initialize_weights() 42 | 43 | def get_feat_modules(self): 44 | feat_m = nn.ModuleList([]) 45 | feat_m.append(self.block0) 46 | feat_m.append(self.pool0) 47 | feat_m.append(self.block1) 48 | feat_m.append(self.pool1) 49 | feat_m.append(self.block2) 50 | feat_m.append(self.pool2) 51 | feat_m.append(self.block3) 52 | feat_m.append(self.pool3) 53 | feat_m.append(self.block4) 54 | feat_m.append(self.pool4) 55 | return feat_m 56 | 57 | def get_bn_before_relu(self): 58 | bn1 = self.block1[-1] 59 | bn2 = self.block2[-1] 60 | bn3 = self.block3[-1] 61 | bn4 = self.block4[-1] 62 | return [bn1, bn2, bn3, bn4] 63 | 64 | def forward(self, x, is_feat=False, preact=False): 65 | h = x.shape[2] 66 | x = F.relu(self.block0(x)) 67 | f0 = x 68 | x = self.pool0(x) 69 | x = self.block1(x) 70 | f1_pre = x 71 | x = F.relu(x) 72 | f1 = x 73 | x = self.pool1(x) 74 | x = self.block2(x) 75 | f2_pre = x 76 | x = F.relu(x) 77 | f2 = x 78 | x = self.pool2(x) 79 | x = self.block3(x) 80 | f3_pre = x 81 | x = F.relu(x) 82 | f3 = x 83 | if h == 64:  # extra pooling only for 64-pixel inputs; 32x32 CIFAR skips pool3 84 | x = self.pool3(x) 85 | x = self.block4(x) 86 | f4_pre = x 87 | x = F.relu(x) 88 | f4 = x 89 | x = self.pool4(x) 90 | x = x.view(x.size(0), -1) 91 | f5 = x 92 | x = self.classifier(x) 93 | 94 | if is_feat: 95 | if preact: 96 | return [f0, f1_pre, f2_pre, f3_pre, f4_pre, f5], x 97 | else: 98 | return [f0, f1, f2, f3, f4, f5], x 99 | else: 100 | return x 101 | 102 | @staticmethod 103 | def _make_layers(cfg, batch_norm=False, in_channels=3): 104 | layers = [] 105 | for v in cfg: 106 | if v == 'M': 107 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 108 | else: 109 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 110 | if batch_norm: 111 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 112 | else: 113 | layers += [conv2d, nn.ReLU(inplace=True)] 114 | in_channels = v 115 | layers = layers[:-1]  # drop the block's final ReLU; forward() applies it explicitly to expose pre-activations 116 | return nn.Sequential(*layers) 117 | 118 | def _initialize_weights(self): 119 | for m in self.modules(): 120 | if isinstance(m, nn.Conv2d): 121 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 122 | m.weight.data.normal_(0, 
math.sqrt(2. / n)) 123 | if m.bias is not None: 124 | m.bias.data.zero_() 125 | elif isinstance(m, nn.BatchNorm2d): 126 | m.weight.data.fill_(1) 127 | m.bias.data.zero_() 128 | elif isinstance(m, nn.Linear): 129 | n = m.weight.size(1) 130 | m.weight.data.normal_(0, 0.01) 131 | m.bias.data.zero_() 132 | 133 | 134 | cfg = { 135 | 'A': [[64], [128], [256, 256], [512, 512], [512, 512]], 136 | 'B': [[64, 64], [128, 128], [256, 256], [512, 512], [512, 512]], 137 | 'D': [[64, 64], [128, 128], [256, 256, 256], [512, 512, 512], [512, 512, 512]], 138 | 'E': [[64, 64], [128, 128], [256, 256, 256, 256], [512, 512, 512, 512], [512, 512, 512, 512]], 139 | 'S': [[64], [128], [256], [512], [512]], 140 | } 141 | 142 | 143 | def vgg8(**kwargs): 144 | """VGG 8-layer model (configuration "S"). 145 | 146 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 147 | """ 148 | model = VGG(cfg['S'], **kwargs) 149 | return model 150 | 151 | 152 | def vgg8_bn(**kwargs): 153 | """VGG 8-layer model (configuration "S") with batch normalization. 154 | 155 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 156 | """ 157 | model = VGG(cfg['S'], batch_norm=True, **kwargs) 158 | return model 159 | 160 | 161 | def vgg11(**kwargs): 162 | """VGG 11-layer model (configuration "A"). 163 | 164 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 165 | """ 166 | model = VGG(cfg['A'], **kwargs) 167 | return model 168 | 169 | 170 | def vgg11_bn(**kwargs): 171 | """VGG 11-layer model (configuration "A") with batch normalization""" 172 | model = VGG(cfg['A'], batch_norm=True, **kwargs) 173 | return model 174 | 175 | 176 | def vgg13(**kwargs): 177 | """VGG 13-layer model (configuration "B"). 178 | 179 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 180 | """ 181 | model = VGG(cfg['B'], **kwargs) 182 | return model 183 | 184 | 185 | def vgg13_bn(**kwargs): 186 | """VGG 13-layer model (configuration "B") with batch normalization""" 187 | model = VGG(cfg['B'], batch_norm=True, **kwargs) 188 | return model 189 | 190 | 191 | def vgg16(**kwargs): 192 | """VGG 16-layer model (configuration "D"). 193 | 194 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 195 | """ 196 | model = VGG(cfg['D'], **kwargs) 197 | return model 198 | 199 | 200 | def vgg16_bn(**kwargs): 201 | """VGG 16-layer model (configuration "D") with batch normalization""" 202 | model = VGG(cfg['D'], batch_norm=True, **kwargs) 203 | return model 204 | 205 | 206 | def vgg19(**kwargs): 207 | """VGG 19-layer model (configuration "E"). 208 | 209 | Keyword arguments (e.g. num_classes) are forwarded to the VGG constructor. 210 | """ 211 | model = VGG(cfg['E'], **kwargs) 212 | return model 213 | 214 | 215 | def vgg19_bn(**kwargs): 216 | """VGG 19-layer model (configuration "E") with batch normalization""" 217 | model = VGG(cfg['E'], batch_norm=True, **kwargs) 218 | return model 219 | 220 | 221 | if __name__ == '__main__': 222 | import torch 223 | 224 | x = torch.randn(2, 3, 32, 32) 225 | net = vgg19_bn(num_classes=100) 226 | feats, logit = net(x, is_feat=True, preact=True) 227 | 228 | for f in feats: 229 | print(f.shape, f.min().item()) 230 | print(logit.shape) 231 | 232 | for m in net.get_bn_before_relu(): 233 | if isinstance(m, nn.BatchNorm2d): 234 | print('pass') 235 | else: 236 | print('warning') 237 | -------------------------------------------------------------------------------- /models/wrn.py: -------------------------------------------------------------------------------- 1 | import math 
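# Note (added comment): the WideResNet below enforces depth = 6n+4 (see the assert in WideResNet.__init__); widen_factor multiplies the stage widths [16, 32, 64], so e.g. wrn_40_2 stacks n = 6 blocks per stage with 32/64/128 channels.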
2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | Original Author: Wei Yang 8 | """ 9 | 10 | __all__ = ['wrn'] 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 15 | super(BasicBlock, self).__init__() 16 | self.bn1 = nn.BatchNorm2d(in_planes) 17 | self.relu1 = nn.ReLU(inplace=True) 18 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(out_planes) 21 | self.relu2 = nn.ReLU(inplace=True) 22 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 23 | padding=1, bias=False) 24 | self.droprate = dropRate 25 | self.equalInOut = (in_planes == out_planes) 26 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 27 | padding=0, bias=False) or None 28 | 29 | def forward(self, x): 30 | if not self.equalInOut: 31 | x = self.relu1(self.bn1(x)) 32 | else: 33 | out = self.relu1(self.bn1(x)) 34 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 35 | if self.droprate > 0: 36 | out = F.dropout(out, p=self.droprate, training=self.training) 37 | out = self.conv2(out) 38 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 39 | 40 | 41 | class NetworkBlock(nn.Module): 42 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 43 | super(NetworkBlock, self).__init__() 44 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 45 | 46 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 47 | layers = [] 48 | for i in range(nb_layers): 49 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 50 | return nn.Sequential(*layers) 51 | 52 | def forward(self, x): 53 | return self.layer(x) 54 | 55 | 56 | class WideResNet(nn.Module): 57 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 58 | super(WideResNet, self).__init__() 59 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 60 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 61 | n = (depth - 4) // 6 62 | block = BasicBlock 63 | # 1st conv before any network block 64 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 65 | padding=1, bias=False) 66 | # 1st block 67 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 68 | # 2nd block 69 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 70 | # 3rd block 71 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 72 | # global average pooling and classifier 73 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.fc = nn.Linear(nChannels[3], num_classes) 76 | self.nChannels = nChannels[3] 77 | 78 | for m in self.modules(): 79 | if isinstance(m, nn.Conv2d): 80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 81 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 82 | elif isinstance(m, nn.BatchNorm2d): 83 | m.weight.data.fill_(1) 84 | m.bias.data.zero_() 85 | elif isinstance(m, nn.Linear): 86 | m.bias.data.zero_() 87 | 88 | def get_feat_modules(self): 89 | feat_m = nn.ModuleList([]) 90 | feat_m.append(self.conv1) 91 | feat_m.append(self.block1) 92 | feat_m.append(self.block2) 93 | feat_m.append(self.block3) 94 | return feat_m 95 | 96 | def get_bn_before_relu(self): 97 | bn1 = self.block2.layer[0].bn1 98 | bn2 = self.block3.layer[0].bn1 99 | bn3 = self.bn1 100 | 101 | return [bn1, bn2, bn3] 102 | 103 | def forward(self, x, is_feat=False, preact=False): 104 | out = self.conv1(x) 105 | f0 = out 106 | out = self.block1(out) 107 | f1 = out 108 | out = self.block2(out) 109 | f2 = out 110 | out = self.block3(out) 111 | f3 = out 112 | out = self.relu(self.bn1(out)) 113 | out = F.avg_pool2d(out, 8) 114 | out = out.view(-1, self.nChannels) 115 | f4 = out 116 | out = self.fc(out) 117 | if is_feat: 118 | if preact: 119 | f1 = self.block2.layer[0].bn1(f1) 120 | f2 = self.block3.layer[0].bn1(f2) 121 | f3 = self.bn1(f3) 122 | return [f0, f1, f2, f3, f4], out 123 | else: 124 | return out 125 | 126 | 127 | def wrn(**kwargs): 128 | """ 129 | Constructs a Wide Residual Network (depth = 6n+4). 130 | """ 131 | model = WideResNet(**kwargs) 132 | return model 133 | 134 | 135 | def wrn_40_2(**kwargs): 136 | model = WideResNet(depth=40, widen_factor=2, **kwargs) 137 | return model 138 | 139 | 140 | def wrn_40_1(**kwargs): 141 | model = WideResNet(depth=40, widen_factor=1, **kwargs) 142 | return model 143 | 144 | 145 | def wrn_16_2(**kwargs): 146 | model = WideResNet(depth=16, widen_factor=2, **kwargs) 147 | return model 148 | 149 | 150 | def wrn_16_1(**kwargs): 151 | model = WideResNet(depth=16, widen_factor=1, **kwargs) 152 | return model 153 | 154 | 155 | if __name__ == '__main__': 156 | import torch 157 | 158 | x = torch.randn(2, 3, 32, 32) 159 | net = wrn_40_2(num_classes=100) 160 | feats, logit = net(x, is_feat=True, preact=True) 161 | 162 | for f in feats: 163 | print(f.shape, f.min().item()) 164 | print(logit.shape) 165 | 166 | for m in net.get_bn_before_relu(): 167 | if isinstance(m, nn.BatchNorm2d): 168 | print('pass') 169 | else: 170 | print('warning') 171 | -------------------------------------------------------------------------------- /student.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.utils.data import DataLoader 13 | 14 | import torchvision.transforms as transforms 15 | from tensorboardX import SummaryWriter 16 | 17 | from utils import AverageMeter, accuracy 18 | from wrapper import wrapper 19 | from cifar import CIFAR100 20 | 21 | from models import model_dict 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | parser = argparse.ArgumentParser(description='train SSKD student network.') 26 | parser.add_argument('--epoch', type=int, default=240) 27 | parser.add_argument('--t-epoch', type=int, default=60) 28 | parser.add_argument('--batch-size', type=int, default=64) 29 | 30 | parser.add_argument('--lr', type=float, default=0.05) 31 | parser.add_argument('--t-lr', type=float, default=0.05) 32 | parser.add_argument('--momentum', type=float, default=0.9) 33 | parser.add_argument('--weight-decay', type=float, default=5e-4) 34 | 
parser.add_argument('--gamma', type=float, default=0.1) 35 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210]) 36 | parser.add_argument('--t-milestones', type=int, nargs='+', default=[30,45]) 37 | 38 | parser.add_argument('--save-interval', type=int, default=40) 39 | parser.add_argument('--ce-weight', type=float, default=0.1) # weight of the cross-entropy loss 40 | parser.add_argument('--kd-weight', type=float, default=0.9) # weight of the conventional KD loss 41 | parser.add_argument('--tf-weight', type=float, default=2.7) # weight of the KD loss on transformed images (LT) 42 | parser.add_argument('--ss-weight', type=float, default=10.0) # weight of the self-supervision distillation loss (SS) 43 | 44 | parser.add_argument('--kd-T', type=float, default=4.0) # softmax temperature for KD 45 | parser.add_argument('--tf-T', type=float, default=4.0) # softmax temperature for transformed-image KD (LT) 46 | parser.add_argument('--ss-T', type=float, default=0.5) # softmax temperature for self-supervision similarities (SS) 47 | 48 | parser.add_argument('--ratio-tf', type=float, default=1.0) # fraction of wrongly-predicted LT samples to keep 49 | parser.add_argument('--ratio-ss', type=float, default=0.75) # fraction of wrongly-matched SS samples to keep 50 | parser.add_argument('--s-arch', type=str) # student architecture 51 | parser.add_argument('--t-path', type=str) # teacher checkpoint path 52 | 53 | parser.add_argument('--seed', type=int, default=0) 54 | parser.add_argument('--gpu-id', type=int, default=0) 55 | 56 | args = parser.parse_args() 57 | torch.manual_seed(args.seed) 58 | torch.cuda.manual_seed(args.seed) 59 | np.random.seed(args.seed) 60 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 61 | 62 | 63 | t_name = osp.abspath(args.t_path).split('/')[-1] 64 | t_arch = '_'.join(t_name.split('_')[1:-1]) 65 | exp_name = 'sskd_student_{}_weight{}+{}+{}+{}_T{}+{}+{}_ratio{}+{}_seed{}_{}'.format(\ 66 | args.s_arch, \ 67 | args.ce_weight, args.kd_weight, args.tf_weight, args.ss_weight, \ 68 | args.kd_T, args.tf_T, args.ss_T, \ 69 | args.ratio_tf, args.ratio_ss, \ 70 | args.seed, t_name) 71 | exp_path = './experiments/{}'.format(exp_name) 72 | os.makedirs(exp_path, exist_ok=True) 73 | 74 | transform_train = transforms.Compose([ 75 | transforms.RandomCrop(32, padding=4), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]), 78 | ]) 79 | transform_test = transforms.Compose([ 80 | transforms.ToTensor(), 81 | transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]), 82 | ]) 83 | 84 | trainset = CIFAR100('./data', train=True, transform=transform_train) 85 | valset = CIFAR100('./data', train=False, transform=transform_test) 86 | 87 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False) 88 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False) 89 | 90 | ckpt_path = osp.join(args.t_path, 'ckpt/best.pth') 91 | t_model = model_dict[t_arch](num_classes=100).cuda() 92 | state_dict = torch.load(ckpt_path)['state_dict'] 93 | t_model.load_state_dict(state_dict) 94 | t_model = wrapper(module=t_model).cuda() 95 | 96 | t_optimizer = optim.SGD([{'params':t_model.backbone.parameters(), 'lr':0.0}, # lr 0: backbone frozen, only proj_head trains 97 | {'params':t_model.proj_head.parameters(), 'lr':args.t_lr}], 98 | momentum=args.momentum, weight_decay=args.weight_decay) 99 | t_model.eval() 100 | t_scheduler = MultiStepLR(t_optimizer, milestones=args.t_milestones, gamma=args.gamma) 101 | 102 | logger = SummaryWriter(osp.join(exp_path, 'events')) 103 | 104 | acc_record = AverageMeter() 105 | loss_record = AverageMeter() 106 | start = 
time.time() 107 | for x, target in val_loader: 108 | 109 | x = x[:,0,:,:,:].cuda() 110 | target = target.cuda() 111 | with torch.no_grad(): 112 | output, _, feat = t_model(x) 113 | loss = F.cross_entropy(output, target) 114 | 115 | batch_acc = accuracy(output, target, topk=(1,))[0] 116 | acc_record.update(batch_acc.item(), x.size(0)) 117 | loss_record.update(loss.item(), x.size(0)) 118 | 119 | run_time = time.time() - start 120 | info = 'teacher cls_acc:{:.2f}\n'.format(acc_record.avg) 121 | print(info) 122 | 123 | # train ssp_head 124 | for epoch in range(args.t_epoch): 125 | 126 | t_model.eval() 127 | loss_record = AverageMeter() 128 | acc_record = AverageMeter() 129 | 130 | start = time.time() 131 | for x, _ in train_loader: 132 | 133 | t_optimizer.zero_grad() 134 | 135 | x = x.cuda() 136 | c,h,w = x.size()[-3:] 137 | x = x.view(-1, c, h, w) 138 | 139 | _, rep, feat = t_model(x, bb_grad=False) 140 | batch = int(x.size(0) / 4) 141 | nor_index = (torch.arange(4*batch) % 4 == 0).cuda() 142 | aug_index = (torch.arange(4*batch) % 4 != 0).cuda() 143 | 144 | nor_rep = rep[nor_index] 145 | aug_rep = rep[aug_index] 146 | nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2) 147 | aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch) 148 | simi = F.cosine_similarity(aug_rep, nor_rep, dim=1) 149 | target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda() 150 | loss = F.cross_entropy(simi, target) 151 | 152 | loss.backward() 153 | t_optimizer.step() 154 | 155 | batch_acc = accuracy(simi, target, topk=(1,))[0] 156 | loss_record.update(loss.item(), 3*batch) 157 | acc_record.update(batch_acc.item(), 3*batch) 158 | 159 | logger.add_scalar('train/teacher_ssp_loss', loss_record.avg, epoch+1) 160 | logger.add_scalar('train/teacher_ssp_acc', acc_record.avg, epoch+1) 161 | 162 | run_time = time.time() - start 163 | info = 'teacher_train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t ssp_loss:{:.3f}\t ssp_acc:{:.2f}\t'.format( 164 | epoch+1, args.t_epoch, run_time, loss_record.avg, acc_record.avg) 165 | print(info) 166 | 167 | t_model.eval() 168 | acc_record = AverageMeter() 169 | loss_record = AverageMeter() 170 | start = time.time() 171 | for x, _ in val_loader: 172 | 173 | x = x.cuda() 174 | c,h,w = x.size()[-3:] 175 | x = x.view(-1, c, h, w) 176 | 177 | with torch.no_grad(): 178 | _, rep, feat = t_model(x) 179 | batch = int(x.size(0) / 4) 180 | nor_index = (torch.arange(4*batch) % 4 == 0).cuda() 181 | aug_index = (torch.arange(4*batch) % 4 != 0).cuda() 182 | 183 | nor_rep = rep[nor_index] 184 | aug_rep = rep[aug_index] 185 | nor_rep = nor_rep.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2) 186 | aug_rep = aug_rep.unsqueeze(2).expand(-1,-1,1*batch) 187 | simi = F.cosine_similarity(aug_rep, nor_rep, dim=1) 188 | target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda() 189 | loss = F.cross_entropy(simi, target) 190 | 191 | batch_acc = accuracy(simi, target, topk=(1,))[0] 192 | acc_record.update(batch_acc.item(),3*batch) 193 | loss_record.update(loss.item(), 3*batch) 194 | 195 | run_time = time.time() - start 196 | logger.add_scalar('val/teacher_ssp_loss', loss_record.avg, epoch+1) 197 | logger.add_scalar('val/teacher_ssp_acc', acc_record.avg, epoch+1) 198 | 199 | info = 'ssp_test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t ssp_loss:{:.3f}\t ssp_acc:{:.2f}\n'.format( 200 | epoch+1, args.t_epoch, run_time, loss_record.avg, acc_record.avg) 201 | print(info) 202 | 203 | t_scheduler.step() 204 | 205 | 206 | name = osp.join(exp_path, 
'ckpt/teacher.pth') 207 | os.makedirs(osp.dirname(name), exist_ok=True) 208 | torch.save(t_model.state_dict(), name) 209 | 210 | 211 | s_model = model_dict[args.s_arch](num_classes=100) 212 | s_model = wrapper(module=s_model).cuda() 213 | optimizer = optim.SGD(s_model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 214 | scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma) 215 | 216 | best_acc = 0 217 | for epoch in range(args.epoch): 218 | 219 | # train 220 | s_model.train() 221 | loss1_record = AverageMeter() 222 | loss2_record = AverageMeter() 223 | loss3_record = AverageMeter() 224 | loss4_record = AverageMeter() 225 | cls_acc_record = AverageMeter() 226 | ssp_acc_record = AverageMeter() 227 | 228 | start = time.time() 229 | for x, target in train_loader: 230 | 231 | optimizer.zero_grad() 232 | 233 | c,h,w = x.size()[-3:] 234 | x = x.view(-1,c,h,w).cuda() 235 | target = target.cuda() 236 | 237 | batch = int(x.size(0) / 4) 238 | nor_index = (torch.arange(4*batch) % 4 == 0).cuda() 239 | aug_index = (torch.arange(4*batch) % 4 != 0).cuda() 240 | 241 | output, s_feat, _ = s_model(x, bb_grad=True) 242 | log_nor_output = F.log_softmax(output[nor_index] / args.kd_T, dim=1) 243 | log_aug_output = F.log_softmax(output[aug_index] / args.tf_T, dim=1) 244 | with torch.no_grad(): 245 | knowledge, t_feat, _ = t_model(x) 246 | nor_knowledge = F.softmax(knowledge[nor_index] / args.kd_T, dim=1) 247 | aug_knowledge = F.softmax(knowledge[aug_index] / args.tf_T, dim=1) 248 | 249 | # error level ranking 250 | aug_target = target.unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda() 251 | rank = torch.argsort(aug_knowledge, dim=1, descending=True) 252 | rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1) # groundtruth label's rank 253 | index = torch.argsort(rank) 254 | tmp = torch.nonzero(rank, as_tuple=True)[0] 255 | wrong_num = tmp.numel() 256 | correct_num = 3*batch - wrong_num 257 | wrong_keep = int(wrong_num * args.ratio_tf) 258 | index = index[:correct_num+wrong_keep] 259 | distill_index_tf = torch.sort(index)[0] 260 | 261 | s_nor_feat = s_feat[nor_index] 262 | s_aug_feat = s_feat[aug_index] 263 | s_nor_feat = s_nor_feat.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2) 264 | s_aug_feat = s_aug_feat.unsqueeze(2).expand(-1,-1,1*batch) 265 | s_simi = F.cosine_similarity(s_aug_feat, s_nor_feat, dim=1) 266 | 267 | t_nor_feat = t_feat[nor_index] 268 | t_aug_feat = t_feat[aug_index] 269 | t_nor_feat = t_nor_feat.unsqueeze(2).expand(-1,-1,3*batch).transpose(0,2) 270 | t_aug_feat = t_aug_feat.unsqueeze(2).expand(-1,-1,1*batch) 271 | t_simi = F.cosine_similarity(t_aug_feat, t_nor_feat, dim=1) 272 | 273 | t_simi = t_simi.detach() 274 | aug_target = torch.arange(batch).unsqueeze(1).expand(-1,3).contiguous().view(-1).long().cuda() 275 | rank = torch.argsort(t_simi, dim=1, descending=True) 276 | rank = torch.argmax(torch.eq(rank, aug_target.unsqueeze(1)).long(), dim=1) # groundtruth label's rank 277 | index = torch.argsort(rank) 278 | tmp = torch.nonzero(rank, as_tuple=True)[0] 279 | wrong_num = tmp.numel() 280 | correct_num = 3*batch - wrong_num 281 | wrong_keep = int(wrong_num * args.ratio_ss) 282 | index = index[:correct_num+wrong_keep] 283 | distill_index_ss = torch.sort(index)[0] 284 | 285 | log_simi = F.log_softmax(s_simi / args.ss_T, dim=1) 286 | simi_knowledge = F.softmax(t_simi / args.ss_T, dim=1) 287 | 288 | loss1 = F.cross_entropy(output[nor_index], target) 289 | loss2 = F.kl_div(log_nor_output, nor_knowledge, 
reduction='batchmean') * args.kd_T * args.kd_T 290 | loss3 = F.kl_div(log_aug_output[distill_index_tf], aug_knowledge[distill_index_tf], \ 291 | reduction='batchmean') * args.tf_T * args.tf_T 292 | loss4 = F.kl_div(log_simi[distill_index_ss], simi_knowledge[distill_index_ss], \ 293 | reduction='batchmean') * args.ss_T * args.ss_T 294 | 295 | loss = args.ce_weight * loss1 + args.kd_weight * loss2 + args.tf_weight * loss3 + args.ss_weight * loss4 296 | 297 | loss.backward() 298 | optimizer.step() 299 | 300 | cls_batch_acc = accuracy(output[nor_index], target, topk=(1,))[0] 301 | ssp_batch_acc = accuracy(s_simi, aug_target, topk=(1,))[0] 302 | loss1_record.update(loss1.item(), batch) 303 | loss2_record.update(loss2.item(), batch) 304 | loss3_record.update(loss3.item(), len(distill_index_tf)) 305 | loss4_record.update(loss4.item(), len(distill_index_ss)) 306 | cls_acc_record.update(cls_batch_acc.item(), batch) 307 | ssp_acc_record.update(ssp_batch_acc.item(), 3*batch) 308 | 309 | logger.add_scalar('train/ce_loss', loss1_record.avg, epoch+1) 310 | logger.add_scalar('train/kd_loss', loss2_record.avg, epoch+1) 311 | logger.add_scalar('train/tf_loss', loss3_record.avg, epoch+1) 312 | logger.add_scalar('train/ss_loss', loss4_record.avg, epoch+1) 313 | logger.add_scalar('train/cls_acc', cls_acc_record.avg, epoch+1) 314 | logger.add_scalar('train/ss_acc', ssp_acc_record.avg, epoch+1) 315 | 316 | run_time = time.time() - start 317 | info = 'student_train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t ce_loss:{:.3f}\t kd_loss:{:.3f}\t cls_acc:{:.2f}'.format( 318 | epoch+1, args.epoch, run_time, loss1_record.avg, loss2_record.avg, cls_acc_record.avg) 319 | print(info) 320 | 321 | # cls val 322 | s_model.eval() 323 | acc_record = AverageMeter() 324 | loss_record = AverageMeter() 325 | start = time.time() 326 | for x, target in val_loader: 327 | 328 | x = x[:,0,:,:,:].cuda() 329 | target = target.cuda() 330 | with torch.no_grad(): 331 | output, _, feat = s_model(x) 332 | loss = F.cross_entropy(output, target) 333 | 334 | batch_acc = accuracy(output, target, topk=(1,))[0] 335 | acc_record.update(batch_acc.item(), x.size(0)) 336 | loss_record.update(loss.item(), x.size(0)) 337 | 338 | run_time = time.time() - start 339 | logger.add_scalar('val/ce_loss', loss_record.avg, epoch+1) 340 | logger.add_scalar('val/cls_acc', acc_record.avg, epoch+1) 341 | 342 | info = 'student_test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t cls_acc:{:.2f}\n'.format( 343 | epoch+1, args.epoch, run_time, acc_record.avg) 344 | print(info) 345 | 346 | if acc_record.avg > best_acc: 347 | best_acc = acc_record.avg 348 | state_dict = dict(epoch=epoch+1, state_dict=s_model.state_dict(), best_acc=best_acc) 349 | name = osp.join(exp_path, 'ckpt/student_best.pth') 350 | os.makedirs(osp.dirname(name), exist_ok=True) 351 | torch.save(state_dict, name) 352 | 353 | scheduler.step() 354 | 355 | -------------------------------------------------------------------------------- /teacher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import argparse 4 | import time 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import MultiStepLR 12 | from torch.utils.data import DataLoader 13 | 14 | import torchvision.transforms as transforms 15 | from torchvision.datasets import CIFAR100 16 | from tensorboardX import SummaryWriter 17 | 18 | from utils import AverageMeter, 
accuracy 19 | from models import model_dict 20 | 21 | torch.backends.cudnn.benchmark = True 22 | 23 | parser = argparse.ArgumentParser(description='train teacher network.') 24 | parser.add_argument('--epoch', type=int, default=240) 25 | parser.add_argument('--batch-size', type=int, default=64) 26 | 27 | parser.add_argument('--lr', type=float, default=0.05) 28 | parser.add_argument('--momentum', type=float, default=0.9) 29 | parser.add_argument('--weight-decay', type=float, default=5e-4) 30 | parser.add_argument('--gamma', type=float, default=0.1) 31 | parser.add_argument('--milestones', type=int, nargs='+', default=[150,180,210]) 32 | 33 | parser.add_argument('--save-interval', type=int, default=40) 34 | parser.add_argument('--arch', type=str) 35 | parser.add_argument('--seed', type=int, default=0) 36 | parser.add_argument('--gpu-id', type=int, default=0) 37 | 38 | args = parser.parse_args() 39 | torch.manual_seed(args.seed) 40 | torch.cuda.manual_seed(args.seed) 41 | np.random.seed(args.seed) 42 | 43 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 44 | 45 | exp_name = 'teacher_{}_seed{}'.format(args.arch, args.seed) 46 | exp_path = './experiments/{}'.format(exp_name) 47 | os.makedirs(exp_path, exist_ok=True) 48 | 49 | transform_train = transforms.Compose([ 50 | transforms.RandomCrop(32, padding=4), 51 | transforms.RandomHorizontalFlip(), 52 | transforms.ToTensor(), 53 | transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]), 54 | ]) 55 | transform_test = transforms.Compose([ 56 | transforms.ToTensor(), 57 | transforms.Normalize(mean=[0.5071, 0.4866, 0.4409], std=[0.2675, 0.2565, 0.2761]), 58 | ]) 59 | 60 | trainset = CIFAR100('./data', train=True, transform=transform_train, download=True) 61 | valset = CIFAR100('./data', train=False, transform=transform_test, download=True) 62 | train_loader = DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=False) 63 | val_loader = DataLoader(valset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=False) 64 | 65 | model = model_dict[args.arch](num_classes=100).cuda() 66 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 67 | scheduler = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma) 68 | 69 | logger = SummaryWriter(osp.join(exp_path, 'events')) 70 | best_acc = -1 71 | for epoch in range(args.epoch): 72 | 73 | model.train() 74 | loss_record = AverageMeter() 75 | acc_record = AverageMeter() 76 | 77 | start = time.time() 78 | for x, target in train_loader: 79 | 80 | optimizer.zero_grad() 81 | x = x.cuda() 82 | target = target.cuda() 83 | 84 | output = model(x) 85 | loss = F.cross_entropy(output, target) 86 | 87 | loss.backward() 88 | optimizer.step() 89 | 90 | batch_acc = accuracy(output, target, topk=(1,))[0] 91 | loss_record.update(loss.item(), x.size(0)) 92 | acc_record.update(batch_acc.item(), x.size(0)) 93 | 94 | logger.add_scalar('train/cls_loss', loss_record.avg, epoch+1) 95 | logger.add_scalar('train/cls_acc', acc_record.avg, epoch+1) 96 | 97 | run_time = time.time() - start 98 | 99 | info = 'train_Epoch:{:03d}/{:03d}\t run_time:{:.3f}\t cls_loss:{:.3f}\t cls_acc:{:.2f}\t'.format( 100 | epoch+1, args.epoch, run_time, loss_record.avg, acc_record.avg) 101 | print(info) 102 | 103 | model.eval() 104 | acc_record = AverageMeter() 105 | loss_record = AverageMeter() 106 | start = time.time() 107 | for x, target in val_loader: 108 | 109 | x = x.cuda() 110 | target = target.cuda() 
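# (added comment) the validation forward pass below runs under torch.no_grad(), so no autograd state is kept; this cuts memory and complements the model.eval() call above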
111 | with torch.no_grad(): 112 | output = model(x) 113 | loss = F.cross_entropy(output, target) 114 | 115 | batch_acc = accuracy(output, target, topk=(1,))[0] 116 | loss_record.update(loss.item(), x.size(0)) 117 | acc_record.update(batch_acc.item(), x.size(0)) 118 | 119 | run_time = time.time() - start 120 | 121 | logger.add_scalar('val/cls_loss', loss_record.avg, epoch+1) 122 | logger.add_scalar('val/cls_acc', acc_record.avg, epoch+1) 123 | 124 | info = 'test_Epoch:{:03d}/{:03d}\t run_time:{:.2f}\t cls_loss:{:.3f}\t cls_acc:{:.2f}\n'.format( 125 | epoch+1, args.epoch, run_time, loss_record.avg, acc_record.avg) 126 | print(info) 127 | 128 | scheduler.step() 129 | 130 | # save checkpoint 131 | if (epoch+1) in args.milestones or epoch+1==args.epoch or (epoch+1)%args.save_interval==0: 132 | state_dict = dict(epoch=epoch+1, state_dict=model.state_dict(), acc=acc_record.avg) 133 | name = osp.join(exp_path, 'ckpt/{:03d}.pth'.format(epoch+1)) 134 | os.makedirs(osp.dirname(name), exist_ok=True) 135 | torch.save(state_dict, name) 136 | 137 | # save best 138 | if acc_record.avg > best_acc: 139 | state_dict = dict(epoch=epoch+1, state_dict=model.state_dict(), acc=acc_record.avg) 140 | name = osp.join(exp_path, 'ckpt/best.pth') 141 | os.makedirs(osp.dirname(name), exist_ok=True) 142 | torch.save(state_dict, name) 143 | best_acc = acc_record.avg 144 | 145 | print('best_acc: {:.2f}'.format(best_acc)) 146 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | 5 | import torch 6 | from torch.nn import init 7 | 8 | class AverageMeter(object): 9 | """Computes and stores the average and current value""" 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.count = 0 15 | self.sum = 0.0 16 | self.val = 0.0 17 | self.avg = 0.0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | 25 | def accuracy(output, target, topk=(1,)): 26 | """Computes the precision@k for the specified values of k""" 27 | maxk = max(topk) 28 | batch_size = target.size(0) 29 | 30 | _, pred = output.topk(maxk, 1, True, True) 31 | pred = pred.t() 32 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 33 | 34 | res = [] 35 | for k in topk: 36 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 37 | res.append(correct_k.mul_(100.0 / batch_size)) 38 | return res 39 | 40 | def norm(x): 41 | 42 | n = np.linalg.norm(x) 43 | return x / n 44 | -------------------------------------------------------------------------------- /wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class wrapper(nn.Module): 6 | 7 | def __init__(self, module): 8 | 9 | super(wrapper, self).__init__() 10 | 11 | self.backbone = module 12 | feat_dim = list(module.children())[-1].in_features 13 | self.proj_head = nn.Sequential( 14 | nn.Linear(feat_dim, feat_dim), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(feat_dim, feat_dim) 17 | ) 18 | 19 | def forward(self, x, bb_grad=True): 20 | 21 | feats, out = self.backbone(x, is_feat=True) 22 | feat = feats[-1].view(feats[-1].size(0), -1) 23 | if not bb_grad: 24 | feat = feat.detach() 25 | 26 | return out, self.proj_head(feat), feat 27 | 28 | 
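29 | # A minimal smoke test (added), mirroring the __main__ blocks in models/*.py. 30 | # The wrn_16_2 backbone and 32x32 input are illustrative choices only; any model 31 | # whose last child module is an nn.Linear and whose forward accepts is_feat=True works. 32 | if __name__ == '__main__': 33 | from models.wrn import wrn_16_2 34 | 35 | net = wrapper(module=wrn_16_2(num_classes=100)) 36 | x = torch.randn(2, 3, 32, 32) 37 | out, proj, feat = net(x, bb_grad=False) 38 | print(out.shape, proj.shape, feat.shape) # (2, 100), (2, 128), (2, 128)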
--------------------------------------------------------------------------------