├── dataset
│   ├── __init__.py
│   └── tinyimagenet.py
├── figures
│   ├── Fig1.png
│   ├── Fig2.png
│   ├── Fig3.png
│   └── Fig4.png
├── models
│   ├── __init__.py
│   ├── student
│   │   ├── __init__.py
│   │   ├── fcnet.py
│   │   ├── lenet.py
│   │   ├── mynet.py
│   │   └── resnet_s.py
│   ├── teacher
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── hint.py
│   │   ├── vgg.py
│   │   ├── wide_resnet.py
│   │   ├── googlenet.py
│   │   ├── dpn.py
│   │   ├── resnext.py
│   │   ├── densenet.py
│   │   ├── preact_resnet.py
│   │   ├── resnet.py
│   │   └── resnet20.py
│   └── embedding.py
├── utils
│   ├── __init__.py
│   ├── feature_shape.py
│   ├── model_init.py
│   └── averagemeter.py
├── LICENSE
├── .gitignore
├── README.md
├── metric
│   └── loss.py
├── train_resnet.py
├── FitNet_distill_hook.ipynb
├── RKD.ipynb
├── group_Hint_only.ipynb
├── DML_3S.ipynb
├── multi_teacher_avg_distill.ipynb
└── adaptive_FitNet_teacher-level.ipynb
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .tinyimagenet import *
--------------------------------------------------------------------------------
/figures/Fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FLHonker/AMTML-KD-code/HEAD/figures/Fig1.png
--------------------------------------------------------------------------------
/figures/Fig2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FLHonker/AMTML-KD-code/HEAD/figures/Fig2.png
--------------------------------------------------------------------------------
/figures/Fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FLHonker/AMTML-KD-code/HEAD/figures/Fig3.png
--------------------------------------------------------------------------------
/figures/Fig4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FLHonker/AMTML-KD-code/HEAD/figures/Fig4.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .teacher import *
2 | from .student import *
3 | from .embedding import *
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .averagemeter import *
2 | from .model_init import *
3 | from .feature_shape import *
--------------------------------------------------------------------------------
/models/student/__init__.py:
--------------------------------------------------------------------------------
1 | from .lenet import *
2 | from .mynet import *
3 | from .fcnet import *
4 | from .resnet_s import *
--------------------------------------------------------------------------------
/models/teacher/__init__.py:
--------------------------------------------------------------------------------
1 | from .alexnet import *
2 | from .vgg import *
3 | from .dpn import *
4 | from .densenet import *
5 | from .googlenet import *
6 | from .resnet import *
7 | from .resnext import *
8 | from .preact_resnet import *
9 | from .hint import *
10 | from .resnet20 import *
11 | from .wide_resnet import *
--------------------------------------------------------------------------------
/utils/feature_shape.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def get_feature_shape(model):
4 |     model.cuda()
5 |     input = torch.randn(128, 3, 224, 224).cuda()  # dummy batch used to probe the feature shapes
6 |     b1, b2, b3, pool, out = model(input)
7 |     feat_maps = [b1, b2, b3, pool, out]
8 |     feat_shapes = [e.size() for e in feat_maps]
9 |     print(feat_shapes)
10 |     return feat_shapes
--------------------------------------------------------------------------------
/utils/model_init.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | # weights init
4 | def weights_init_normal(m):
5 |     classname = m.__class__.__name__
6 |     if classname.find('Conv') != -1:
7 |         nn.init.normal_(m.weight.data, 0.0, 0.02)
8 |     elif classname.find("BatchNorm2d") != -1:
9 |         nn.init.normal_(m.weight.data, 1.0, 0.02)
10 |         nn.init.constant_(m.bias.data, 0.0)
11 |     elif classname.find('Linear') != -1:
12 |         nn.init.normal_(m.weight.data, 1.0, 0.02)
13 |         nn.init.constant_(m.bias.data, 0.0)
--------------------------------------------------------------------------------
/models/student/fcnet.py:
--------------------------------------------------------------------------------
1 | # ---------------------
2 | # Student Nets
3 | # ---------------------
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | # Linear Net-1
9 | class fcNet(nn.Module):
10 |     def __init__(self):
11 |         super(fcNet, self).__init__()
12 |         self.fc1 = nn.Linear(32 * 32 * 3, 1200)
13 |         self.fc2 = nn.Linear(1200, 1200)
14 |         self.fc3 = nn.Linear(1200, 10)
15 | 
16 |     def forward(self, x):
17 |         x = x.view(-1, 32 * 32 * 3)
18 |         x = F.relu(self.fc1(x))
19 |         x = F.dropout(x, p=0.8, training=self.training)
20 |         x = F.relu(self.fc2(x))
21 |         x = F.dropout(x, p=0.8, training=self.training)
22 |         x = self.fc3(x)
23 |         return x
24 | 
--------------------------------------------------------------------------------
/models/student/lenet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | # LeNet-5
5 | class LeNet5(nn.Module):
6 |     def __init__(self, num_classes=10):
7 |         super(LeNet5, self).__init__()
8 |         self.conv1 = nn.Conv2d(3, 6, 5)
9 |         self.pool = nn.MaxPool2d(2, 2)
10 |         self.conv2 = nn.Conv2d(6, 16, 5)
11 |         self.fc1 = nn.Linear(16 * 5 * 5, 120)
12 |         self.fc2 = nn.Linear(120, 84)
13 |         self.fc3 = nn.Linear(84, num_classes)
14 | 
15 |     def forward(self, x):
16 |         x = self.pool(F.relu(self.conv1(x)))
17 |         x = self.pool(F.relu(self.conv2(x)))
18 |         x = x.view(-1, 16 * 5 * 5)
19 |         x = F.relu(self.fc1(x))
20 |         x = F.relu(self.fc2(x))
21 |         x = self.fc3(x)
22 |         return x
23 | 
--------------------------------------------------------------------------------
/models/embedding.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | 
4 | __all__ = ["LinearEmbedding"]
5 | 
6 | 
7 | class LinearEmbedding(nn.Module):
8 |     def __init__(self, base, output_size=64, embedding_size=64, normalize=True):
9 |         super(LinearEmbedding, self).__init__()
10 |         self.base = base
11 |         self.linear = nn.Linear(output_size, embedding_size)
12 |         self.normalize = normalize
13 | 
14 |     def forward(self, x, get_ha=True):
15 |         if get_ha:
16 |             b1, b2, b3, pool, out = self.base(x, True)
17 |         else:
18 |             pool = self.base(x)
19 | 
20 |         pool = pool.view(x.size(0), -1)
21 |         embedding = self.linear(pool)
22 | 
23 |         if self.normalize:
24 |             embedding = F.normalize(embedding, p=2, dim=1)
25 | 
26 |         if get_ha:
27 |             return b1, b2, b3, pool, embedding
28 | 
29 |         return embedding
30 | 
-------------------------------------------------------------------------------- /models/student/mynet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | # myNet 6 | class myNet(nn.Module): 7 | def __init__(self, num_classes=10): 8 | super(myNet, self).__init__() 9 | self.conv1 = nn.Conv2d(3, 16, 5) 10 | self.conv2 = nn.Conv2d(16, 32, 3, padding=1) 11 | self.conv3 = nn.Conv2d(32, 64, 5) 12 | self.conv4 = nn.Conv2d(64, 16, 5, padding=2) 13 | self.pool = nn.MaxPool2d(2, 2) 14 | 15 | self.fc1 = nn.Linear(16 * 5 * 5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, num_classes) 18 | 19 | def forward(self, x): 20 | x = F.relu(self.conv1(x)) 21 | x = self.pool(F.relu(self.conv2(x))) 22 | x = F.relu(self.conv3(x)) 23 | x = self.pool(F.relu(self.conv4(x))) 24 | x = x.view(-1, 16 * 5 * 5) 25 | x = F.relu(self.fc1(x)) 26 | x = F.relu(self.fc2(x)) 27 | x = self.fc3(x) 28 | return x 29 | -------------------------------------------------------------------------------- /utils/averagemeter.py: -------------------------------------------------------------------------------- 1 | # AverageMeter 2 | class AverageMeter(object): 3 | """Computes and stores the average and current value""" 4 | def __init__(self): 5 | self.reset() 6 | 7 | def reset(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def update(self, val, n=1): 14 | self.val = val 15 | self.sum += val * n 16 | self.count += n 17 | self.avg = self.sum / self.count 18 | 19 | # compute accuracy 20 | def accuracy(output, target, topk=(1,)): 21 | """Computes the precision@k for the specified values of k""" 22 | maxk = max(topk) 23 | batch_size = target.size(0) 24 | 25 | _, pred = output.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in topk: 31 | correct_k = correct[:k].view(-1).float().sum(0) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Frank 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/models/teacher/alexnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # AlexNet model
5 | class AlexNet(nn.Module):
6 | 
7 |     def __init__(self, num_classes=10):
8 |         super(AlexNet, self).__init__()
9 |         self.features = nn.Sequential(
10 |             nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
11 |             nn.ReLU(inplace=True),
12 |             nn.MaxPool2d(kernel_size=2),
13 |             nn.Conv2d(64, 192, kernel_size=3, padding=1),
14 |             nn.ReLU(inplace=True),
15 |             nn.MaxPool2d(kernel_size=2),
16 |             nn.Conv2d(192, 384, kernel_size=3, padding=1),
17 |             nn.ReLU(inplace=True),
18 |             nn.Conv2d(384, 256, kernel_size=3, padding=1),
19 |             nn.ReLU(inplace=True),
20 |             nn.Conv2d(256, 256, kernel_size=3, padding=1),
21 |             nn.ReLU(inplace=True),
22 |             nn.MaxPool2d(kernel_size=2),
23 |         )
24 |         self.classifier = nn.Sequential(
25 |             nn.Dropout(),
26 |             nn.Linear(256 * 2 * 2, 4096),
27 |             nn.ReLU(inplace=True),
28 |             nn.Dropout(),
29 |             nn.Linear(4096, 4096),
30 |             nn.ReLU(inplace=True),
31 |             nn.Linear(4096, num_classes),
32 |         )
33 | 
34 |     def forward(self, x):
35 |         x = self.features(x)
36 |         x = x.view(x.size(0), 256 * 2 * 2)
37 |         x = self.classifier(x)
38 |         return x
--------------------------------------------------------------------------------
/models/teacher/hint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # Hint Nets
5 | # settings
6 | hint_cfg = {
7 |     'Hint5':[256, 512, 'M', 1024, 1024, 'M', 512],
8 |     'Hint7':[256, 512, 'M', 1024, 1024, 'M', 2048, 'M', 1024, 512, 'M'],
9 |     'Hint9':[128, 256, 'M', 512, 1024, 'M', 2048, 1024, 'M', 512, 512, 'M'],
10 | }
11 | 
12 | # model
13 | class Hint(nn.Module):
14 |     def __init__(self, hint_name, num_classes=10):
15 |         super(Hint, self).__init__()
16 |         self.features = self._make_layers(hint_cfg[hint_name])
17 |         # self.classifier = nn.Linear(2048, num_classes)
18 |         self.classifier = nn.Sequential(
19 |             nn.Linear(32768, 4096),  # 512 * 8 * 8: matches Hint5 on 32x32 inputs (Hint7/Hint9 flatten to 512 * 2 * 2)
20 |             nn.ReLU(inplace=True),
21 |             nn.Dropout(),
22 |             nn.Linear(4096, 4096),
23 |             nn.ReLU(inplace=True),
24 |             nn.Dropout(),
25 |             nn.Linear(4096, num_classes)
26 |         )
27 | 
28 |     def forward(self, x):
29 |         x = self.features(x)
30 |         x = x.view(x.size(0), -1)
31 |         x = self.classifier(x)
32 |         return x
33 | 
34 |     def _make_layers(self, cfg):
35 |         layers = []
36 |         in_channels = 3
37 |         for h in cfg:
38 |             if h == 'M':
39 |                 layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
40 |             else:
41 |                 layers += [nn.Conv2d(in_channels, h, kernel_size=3, padding=1),
42 |                            nn.BatchNorm2d(h),
43 |                            nn.ReLU(inplace=True)]
44 |                 in_channels = h
45 |         layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
46 |         return nn.Sequential(*layers)
47 | 
48 | # Hints
49 | def Hint5(num_classes=10):
50 |     return Hint('Hint5', num_classes)
51 | 
52 | def Hint7(num_classes=10):
53 |     return Hint('Hint7', num_classes)
54 | 
55 | def Hint9(num_classes=10):
56 |     return Hint('Hint9', num_classes)
57 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.pth
6 | *.pth.tar
7 | 
8 | # C extensions
9 | *.so
10 | 
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | #  Usually these files are written by a python script from a template
32 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 | 
52 | # Translations
53 | *.mo
54 | *.pot
55 | 
56 | # Django stuff:
57 | #*.log
58 | local_settings.py
59 | db.sqlite3
60 | 
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 | 
65 | # Scrapy stuff:
66 | .scrapy
67 | 
68 | # Sphinx documentation
69 | docs/_build/
70 | 
71 | # PyBuilder
72 | target/
73 | 
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 | 
77 | # pyenv
78 | .python-version
79 | 
80 | # celery beat schedule file
81 | celerybeat-schedule
82 | 
83 | # SageMath parsed files
84 | *.sage.py
85 | 
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 | 
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 | 
99 | # Rope project settings
100 | .ropeproject
101 | 
102 | # mkdocs documentation
103 | /site
104 | 
105 | # mypy
106 | .mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adaptive Multi-Teacher Multi-level Knowledge Distillation (AMTML-KD)
2 | 
3 | The paper was published in Neurocomputing 415 (2020): 106–113.
4 | 
5 | Authors: [Yuang Liu](https://flhonker.github.io/), Wei Zhang and Jun Wang.
6 | 
7 | Links: [ [pdf](https://arxiv.org/pdf/2012.00573) ] [ [code](https://github.com/FLHonker/AMTML-KD-code) ]
8 | 
9 | ## Requirements
10 | 
11 | * PyTorch >= 1.0.0
12 | * Jupyter
13 | * visdom
14 | 
15 | ## Introduction
16 | 
17 | Knowledge distillation (KD) is an effective learning paradigm for improving the performance of lightweight student networks by utilizing additional supervision knowledge distilled from teacher networks. Most pioneering studies either learn from only a single teacher in their distillation learning methods, neglecting the potential that a student can learn from multiple teachers simultaneously, or simply treat every teacher as equally important, failing to reveal the different importance of teachers for specific examples. To bridge this gap, we propose a novel adaptive multi-teacher multi-level knowledge distillation learning framework (**AMTML-KD**), which consists of two novel insights: (i) associating each teacher with a latent representation to adaptively learn instance-level teacher importance weights, which are leveraged to acquire integrated soft targets (high-level knowledge), and (ii) enabling intermediate-level hints (intermediate-level knowledge) to be gathered from multiple teachers via the proposed multi-group hint strategy. A student model can thus learn multi-level knowledge from multiple teachers through AMTML-KD. Extensive results on publicly available datasets demonstrate that the proposed learning framework enables the student to achieve better performance than strong competitors.
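The adaptive weighting idea can be sketched in a few lines of PyTorch. The snippet below is illustrative only — the function names, tensor shapes, and the temperature `T` are assumptions for exposition, not the repository's actual training code:

```python
import torch
import torch.nn.functional as F

def integrated_soft_targets(teacher_logits, teacher_latents, instance_feat, T=3.0):
    """Sketch: combine M teachers' logits into per-instance soft targets.

    teacher_logits : list of M tensors of shape (N, C), one per teacher
    teacher_latents: (M, D) learnable latent representation, one row per teacher
    instance_feat  : (N, D) per-instance features used to score the teachers
    """
    scores = instance_feat @ teacher_latents.t()         # (N, M) teacher scores
    weights = F.softmax(scores, dim=1)                   # instance-level teacher weights
    soft = torch.stack([F.softmax(l / T, dim=1)          # (M, N, C) softened outputs
                        for l in teacher_logits])
    return (weights.t().unsqueeze(2) * soft).sum(dim=0)  # (N, C) integrated targets

def kd_loss(student_logits, targets, T=3.0):
    # KL divergence against the integrated soft targets, scaled by T^2 as usual
    return F.kl_div(F.log_softmax(student_logits / T, dim=1),
                    targets, reduction='batchmean') * T * T
```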
18 | 19 | ![adaptive](./figures/Fig1.png) 20 | 21 | ![framework](./figures/Fig2.png) 22 | 23 | ![multi-teacher](figures/Fig3.png) 24 | 25 | ![examples](figures/Fig4.png) 26 | 27 | ## Citation 28 | 29 | ``` 30 | @article{LIU2020106, 31 | title = {Adaptive multi-teacher multi-level knowledge distillation}, 32 | author = {Yuang Liu and Wei Zhang and Jun Wang}, 33 | journal = {Neurocomputing}, 34 | volume = {415}, 35 | pages = {106 -- 113}, 36 | year = {2020}, 37 | issn = {0925 -- 2312}, 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /models/teacher/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # VGGs 5 | # settings 6 | cfg = { 7 | 'VGG11':[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 8 | 'VGG13':[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 9 | 'VGG16':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 10 | 'VGG19':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 11 | } 12 | 13 | # model 14 | class VGG(nn.Module): 15 | def __init__(self, vgg_name, num_classes=10, name=None): 16 | super(VGG, self).__init__() 17 | self.model_name = name 18 | 19 | self.features = self._make_layers(cfg[vgg_name]) 20 | # self.classifier = nn.Linear(512, num_classes) 21 | self.classifier = nn.Sequential( 22 | nn.Linear(512, 4096), 23 | nn.ReLU(inplace=True), 24 | nn.Dropout(), 25 | nn.Linear(4096, 4096), 26 | nn.ReLU(inplace=True), 27 | nn.Dropout(), 28 | nn.Linear(4096, num_classes) 29 | ) 30 | 31 | def forward(self, x): 32 | pool = self.features(x) 33 | out = pool.view(pool.size(0), -1) 34 | out = self.classifier(out) 35 | return pool, out 36 | 37 | def _make_layers(self, cfg): 38 | layers = [] 39 | in_channels = 3 40 | for h in cfg: 41 | if h == 'M': 42 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 43 | else : 44 | layers += [nn.Conv2d(in_channels, h, kernel_size=3, padding=1), 45 | nn.BatchNorm2d(h), 46 | nn.ReLU(inplace=True)] 47 | in_channels = h 48 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 49 | return nn.Sequential(*layers) 50 | 51 | # VGGs 52 | def VGG11(num_classes=10): 53 | return VGG('VGG11', num_classes, name='VGG11') 54 | 55 | def VGG13(num_classes=10): 56 | return VGG('VGG13', num_classes, name='VGG13') 57 | 58 | def VGG16(num_classes=10): 59 | return VGG('VGG16', num_classes, name='VGG16') 60 | 61 | def VGG19(num_classes=10): 62 | return VGG('VGG19', num_classes, name='VGG19') -------------------------------------------------------------------------------- /models/teacher/wide_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | wide resnet for cifar in pytorch 3 | Reference: 4 | [1] S. Zagoruyko and N. Komodakis. Wide residual networks. In BMVC, 2016. 
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import math
9 | from .resnet import BasicBlock
10 | 
11 | 
12 | class Wide_ResNet(nn.Module):
13 | 
14 |     def __init__(self, block, layers, wfactor, num_classes=10):
15 |         super(Wide_ResNet, self).__init__()
16 |         self.inplanes = 16
17 |         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
18 |         self.bn1 = nn.BatchNorm2d(16)
19 |         self.relu = nn.ReLU(inplace=True)
20 |         self.layer1 = self._make_layer(block, 16*wfactor, layers[0])
21 |         self.layer2 = self._make_layer(block, 32*wfactor, layers[1], stride=2)
22 |         self.layer3 = self._make_layer(block, 64*wfactor, layers[2], stride=2)
23 |         self.avgpool = nn.AvgPool2d(8, stride=1)
24 |         self.fc = nn.Linear(64*block.expansion*wfactor, num_classes)
25 | 
26 |         for m in self.modules():
27 |             if isinstance(m, nn.Conv2d):
28 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
29 |                 m.weight.data.normal_(0, math.sqrt(2. / n))
30 |             elif isinstance(m, nn.BatchNorm2d):
31 |                 m.weight.data.fill_(1)
32 |                 m.bias.data.zero_()
33 | 
34 |     def _make_layer(self, block, planes, blocks, stride=1):
35 |         # the imported BasicBlock builds its own shortcut when the shape
36 |         # changes, so no explicit downsample module is passed here
37 |         layers = []
38 |         layers.append(block(self.inplanes, planes, stride))
39 |         self.inplanes = planes * block.expansion
40 |         for _ in range(1, blocks):
41 |             layers.append(block(self.inplanes, planes))
42 | 
43 |         return nn.Sequential(*layers)
44 | 
45 |     def forward(self, x):
46 |         x = self.conv1(x)
47 |         x = self.bn1(x)
48 |         x = self.relu(x)
49 | 
50 |         x = self.layer1(x)
51 |         x = self.layer2(x)
52 |         x = self.layer3(x)
53 | 
54 |         x = self.avgpool(x)
55 |         x = x.view(x.size(0), -1)
56 |         x = self.fc(x)
57 | 
58 |         return x
59 | 
60 | 
61 | def wide_resnet_cifar(depth, width, **kwargs):
62 |     assert (depth - 2) % 6 == 0
63 |     n = (depth - 2) // 6
64 |     return Wide_ResNet(BasicBlock, [n, n, n], width, **kwargs)
65 | 
66 | 
67 | if __name__=='__main__':
68 |     net = wide_resnet_cifar(20, 10)
69 |     y = net(torch.randn(1, 3, 32, 32))
70 |     print(isinstance(net, Wide_ResNet))
71 |     print(y.size())
--------------------------------------------------------------------------------
/metric/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | # __all__ = ['FitNet', 'AttentionTransfer', 'logits_distillation_loss']
7 | 
8 | def pdist(e, squared=False, eps=1e-12):
9 |     e_square = e.pow(2).sum(dim=1)
10 |     prod = e @ e.t()
11 |     res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
12 | 
13 |     if not squared:
14 |         res = res.sqrt()
15 | 
16 |     res = res.clone()
17 |     res[range(len(e)), range(len(e))] = 0
18 |     return res
19 | 
20 | 
21 | class FitNet(nn.Module):
22 |     def __init__(self, in_feature, out_feature):
23 |         super().__init__()
24 |         self.in_feature = in_feature
25 |         self.out_feature = out_feature
26 | 
27 |         self.transform = nn.Conv2d(in_feature, out_feature, 1, bias=False)
28 |         self.transform.weight.data.uniform_(-0.005, 0.005)
29 | 
30 |     def forward(self, student, teacher):
31 |         if student.dim() == 2:
32 |             student = student.unsqueeze(2).unsqueeze(3)
33 |             teacher = teacher.unsqueeze(2).unsqueeze(3)
34 |         student = F.normalize(student)
35 |         teacher = F.normalize(teacher)
36 | 
37 |         return (self.transform(student) - teacher).pow(2).mean()
38 | 
39 | 
40 | class AttentionTransfer(nn.Module):
41 |     def forward(self, student, teacher):
42 |         s_attention = F.normalize(student.pow(2).mean(1).view(student.size(0), -1))
43 | 
44 |         with torch.no_grad():
45 |             t_attention = F.normalize(teacher.pow(2).mean(1).view(teacher.size(0), -1))
46 | 
47 |         return (s_attention - t_attention).pow(2).mean()
48 | 
49 | 
50 | class RKdAngle(nn.Module):
51 |     def forward(self, student, teacher):
52 |         # N x C inputs; pairwise differences give N x N x C
53 | 
54 |         with torch.no_grad():
55 |             td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
56 |             norm_td = F.normalize(td, p=2, dim=2)
57 |             t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
58 | 
59 |         sd = (student.unsqueeze(0) - student.unsqueeze(1))
60 |         norm_sd = F.normalize(sd, p=2, dim=2)
61 |         s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
62 | 
63 |         loss = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
64 |         return loss
65 | 
66 | 
67 | class RkdDistance(nn.Module):
68 |     def forward(self, student, teacher):
69 |         with torch.no_grad():
70 |             t_d = pdist(teacher, squared=False)
71 |             mean_td = t_d[t_d>0].mean()
72 |             t_d = t_d / mean_td
73 | 
74 |         d = pdist(student, squared=False)
75 |         mean_d = d[d>0].mean()
76 |         d = d / mean_d
77 | 
78 |         loss = F.smooth_l1_loss(d, t_d, reduction='mean')
79 |         return loss
80 | 
81 | class HardDarkRank(nn.Module):
82 |     def __init__(self, alpha=3, beta=3, permute_len=4):
83 |         super().__init__()
84 |         self.alpha = alpha
85 |         self.beta = beta
86 |         self.permute_len = permute_len
87 | 
88 |     def forward(self, student, teacher):
89 |         score_teacher = -1 * self.alpha * pdist(teacher, squared=False).pow(self.beta)
90 |         score_student = -1 * self.alpha * pdist(student, squared=False).pow(self.beta)
91 | 
92 |         permute_idx = score_teacher.sort(dim=1, descending=True)[1][:, 1:(self.permute_len+1)]
93 |         ordered_student = torch.gather(score_student, 1, permute_idx)
94 | 
95 |         log_prob = (ordered_student - torch.stack([torch.logsumexp(ordered_student[:, i:], dim=1) for i in range(permute_idx.size(1))], dim=1)).sum(dim=1)
96 |         loss = (-1 * log_prob).mean()
97 | 
98 |         return loss
--------------------------------------------------------------------------------
/models/teacher/googlenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | # Inception
5 | class Inception(nn.Module):
6 |     def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
7 |         super(Inception, self).__init__()
8 |         # 1x1 conv branch
9 |         self.b1 = nn.Sequential(
10 |             nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
11 |             nn.BatchNorm2d(kernel_1_x),
12 |             nn.ReLU(inplace=True),  # inplace saves memory but overwrites the input tensor
13 |         )
14 | 
15 |         # 1x1 conv -> 3x3 conv branch
16 |         self.b2 = nn.Sequential(
17 |             nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
18 |             nn.BatchNorm2d(kernel_3_in),
19 |             nn.ReLU(True),
20 |             nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
21 |             nn.BatchNorm2d(kernel_3_x),
22 |             nn.ReLU(True),
23 |         )
24 | 
25 |         # 1x1 conv -> 5x5 conv branch (implemented as two stacked 3x3 convs)
26 |         self.b3 = nn.Sequential(
27 |             nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
28 |             nn.BatchNorm2d(kernel_5_in),
29 |             nn.ReLU(True),
30 |             nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
31 |             nn.BatchNorm2d(kernel_5_x),
32 |             nn.ReLU(True),
33 |             nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
34 |             nn.BatchNorm2d(kernel_5_x),
35 |             nn.ReLU(True)
36 |
) 37 | 38 | # 3x3 pool -> 1x1 conv branch 39 | self.b4 = nn.Sequential( 40 | nn.MaxPool2d(3, stride=1, padding=1), 41 | nn.Conv2d(in_planes, pool_planes, kernel_size=1), 42 | nn.BatchNorm2d(pool_planes), 43 | nn.ReLU(True), 44 | ) 45 | 46 | def forward(self, x): 47 | y1 = self.b1(x) 48 | y2 = self.b2(x) 49 | y3 = self.b3(x) 50 | y4 = self.b4(x) 51 | return torch.cat([y1,y2,y3,y4], 1) 52 | 53 | # GoogLeNet 54 | class GoogLeNet(nn.Module): 55 | def __init__(self, num_classes=10): 56 | super(GoogLeNet, self).__init__() 57 | self.pre_layers = nn.Sequential( 58 | nn.Conv2d(3, 192, kernel_size=3, padding=1), 59 | nn.BatchNorm2d(192), 60 | nn.ReLU(True), 61 | ) 62 | 63 | self.a3 = Inception(192, 64, 96, 128, 16, 32, 32) 64 | self.b3 = Inception(256, 128, 128, 192, 32, 96, 64) 65 | 66 | self.max_pool = nn.MaxPool2d(3, stride=2, padding=1) 67 | 68 | self.a4 = Inception(480, 192, 96, 208, 16, 48, 64) 69 | self.b4 = Inception(512, 160, 112, 224, 24, 64, 64) 70 | self.c4 = Inception(512, 128, 128, 256, 24, 64, 64) 71 | self.d4 = Inception(512, 112, 144, 288, 32, 64, 64) 72 | self.e4 = Inception(528, 256, 160, 320, 32, 128, 128) 73 | 74 | self.a5 = Inception(832, 256, 160, 320, 32, 128, 128) 75 | self.b5 = Inception(832, 384, 192, 384, 48, 128, 128) 76 | 77 | self.avgpool = nn.AvgPool2d(8, stride=1) 78 | self.linear = nn.Linear(1024, num_classes) 79 | 80 | def forward(self, x): 81 | x = self.pre_layers(x) 82 | x = self.a3(x) 83 | x = self.b3(x) 84 | x = self.max_pool(x) 85 | x = self.a4(x) 86 | x = self.b4(x) 87 | x = self.c4(x) 88 | x = self.d4(x) 89 | x = self.e4(x) 90 | x = self.max_pool(x) 91 | x = self.a5(x) 92 | x = self.b5(x) 93 | x = self.avgpool(x) 94 | x = x.view(x.size(0), -1) 95 | x = self.linear(x) 96 | return x -------------------------------------------------------------------------------- /models/student/resnet_s.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ResNet student 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | 10 | def _weights_init(m): 11 | classname = m.__class__.__name__ 12 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 13 | init.kaiming_normal(m.weight) 14 | 15 | class LambdaLayer(nn.Module): 16 | def __init__(self, lambd): 17 | super(LambdaLayer, self).__init__() 18 | self.lambd = lambd 19 | 20 | def forward(self, x): 21 | return self.lambd(x) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, in_planes, planes, stride=1, option='A'): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | if option == 'A': 37 | """ 38 | For CIFAR10 ResNet paper uses option A. 
39 | """ 40 | self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 41 | elif option == 'B': 42 | self.shortcut = nn.Sequential( 43 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 44 | nn.BatchNorm2d(self.expansion * planes) 45 | ) 46 | 47 | def forward(self, x): 48 | out = F.relu(self.bn1(self.conv1(x))) 49 | out = self.bn2(self.conv2(out)) 50 | out += self.shortcut(x) 51 | out = F.relu(out) 52 | return out 53 | 54 | 55 | class ResNet(nn.Module): 56 | def __init__(self, block, num_blocks, num_classes=10, name=None): 57 | super(ResNet, self).__init__() 58 | self.in_planes = 8 59 | self.model_name = name 60 | self.output_size = 32 61 | 62 | self.conv1 = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(8) 64 | self.layer1 = self._make_layer(block, 8, num_blocks[0], stride=1) 65 | self.layer2 = self._make_layer(block, 16, num_blocks[1], stride=2) 66 | self.layer3 = self._make_layer(block, 32, num_blocks[2], stride=2) 67 | self.linear = nn.Linear(32, num_classes) 68 | 69 | self.apply(_weights_init) 70 | 71 | def _make_layer(self, block, planes, num_blocks, stride): 72 | strides = [stride] + [1]*(num_blocks-1) 73 | layers = [] 74 | for stride in strides: 75 | layers.append(block(self.in_planes, planes, stride)) 76 | self.in_planes = planes * block.expansion 77 | 78 | return nn.Sequential(*layers) 79 | 80 | def forward(self, x): 81 | out = F.relu(self.bn1(self.conv1(x))) 82 | b1 = self.layer1(out) 83 | b2 = self.layer2(b1) 84 | b3 = self.layer3(b2) 85 | pool = F.avg_pool2d(b3, b3.size()[3]) 86 | out = pool.view(pool.size(0), -1) 87 | out = self.linear(out) 88 | return b1, b2, b3, pool, out 89 | 90 | 91 | def ResNet8(num_classes=10): 92 | return ResNet(BasicBlock, [3, 3, 3], num_classes, name='ResNet8') 93 | 94 | def ResNet15(num_classes=10): 95 | return ResNet(BasicBlock, [5, 5, 5], num_classes, name='ResNet15') 96 | 97 | def ResNet16(num_classes=10): 98 | return ResNet(BasicBlock, [3, 5, 7], num_classes, name='ResNet16') 99 | -------------------------------------------------------------------------------- /models/teacher/dpn.py: -------------------------------------------------------------------------------- 1 | '''Dual Path Networks in PyTorch.''' 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class Bottleneck(nn.Module): 8 | def __init__(self, last_planes, in_planes, out_planes, dense_depth, stride, first_layer): 9 | super(Bottleneck, self).__init__() 10 | self.out_planes = out_planes 11 | self.dense_depth = dense_depth 12 | 13 | self.conv1 = nn.Conv2d(last_planes, in_planes, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.conv2 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=stride, padding=1, groups=32, bias=False) 16 | self.bn2 = nn.BatchNorm2d(in_planes) 17 | self.conv3 = nn.Conv2d(in_planes, out_planes+dense_depth, kernel_size=1, bias=False) 18 | self.bn3 = nn.BatchNorm2d(out_planes+dense_depth) 19 | 20 | self.shortcut = nn.Sequential() 21 | if first_layer: 22 | self.shortcut = nn.Sequential( 23 | nn.Conv2d(last_planes, out_planes+dense_depth, kernel_size=1, stride=stride, bias=False), 24 | nn.BatchNorm2d(out_planes+dense_depth) 25 | ) 26 | 27 | def forward(self, x): 28 | out = F.relu(self.bn1(self.conv1(x))) 29 | out = F.relu(self.bn2(self.conv2(out))) 30 | out = self.bn3(self.conv3(out)) 31 | x = self.shortcut(x) 32 | d = self.out_planes 33 | out = 
torch.cat([x[:,:d,:,:]+out[:,:d,:,:], x[:,d:,:,:], out[:,d:,:,:]], 1)
34 |         out = F.relu(out)
35 |         return out
36 | 
37 | 
38 | class DPN(nn.Module):
39 |     def __init__(self, cfg):
40 |         super(DPN, self).__init__()
41 |         in_planes, out_planes = cfg['in_planes'], cfg['out_planes']
42 |         num_blocks, dense_depth = cfg['num_blocks'], cfg['dense_depth']
43 | 
44 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
45 |         self.bn1 = nn.BatchNorm2d(64)
46 |         self.last_planes = 64
47 |         self.layer1 = self._make_layer(in_planes[0], out_planes[0], num_blocks[0], dense_depth[0], stride=1)
48 |         self.layer2 = self._make_layer(in_planes[1], out_planes[1], num_blocks[1], dense_depth[1], stride=2)
49 |         self.layer3 = self._make_layer(in_planes[2], out_planes[2], num_blocks[2], dense_depth[2], stride=2)
50 |         self.layer4 = self._make_layer(in_planes[3], out_planes[3], num_blocks[3], dense_depth[3], stride=2)
51 |         self.linear = nn.Linear(out_planes[3]+(num_blocks[3]+1)*dense_depth[3], 10)
52 | 
53 |     def _make_layer(self, in_planes, out_planes, num_blocks, dense_depth, stride):
54 |         strides = [stride] + [1]*(num_blocks-1)
55 |         layers = []
56 |         for i,stride in enumerate(strides):
57 |             layers.append(Bottleneck(self.last_planes, in_planes, out_planes, dense_depth, stride, i==0))
58 |             self.last_planes = out_planes + (i+2) * dense_depth
59 |         return nn.Sequential(*layers)
60 | 
61 |     def forward(self, x):
62 |         out = F.relu(self.bn1(self.conv1(x)))
63 |         b1 = self.layer1(out)
64 |         b2 = self.layer2(b1)
65 |         b3 = self.layer3(b2)
66 |         b4 = self.layer4(b3)
67 |         pool = F.avg_pool2d(b4, 4)
68 |         out = pool.view(pool.size(0), -1)
69 |         out = self.linear(out)
70 |         return b1, b2, b3, b4, pool, out
71 | 
72 | 
73 | def DPN26():
74 |     cfg = {
75 |         'in_planes': (96,192,384,768),
76 |         'out_planes': (256,512,1024,2048),
77 |         'num_blocks': (2,2,2,2),
78 |         'dense_depth': (16,32,24,128)
79 |     }
80 |     return DPN(cfg)
81 | 
82 | def DPN92():
83 |     cfg = {
84 |         'in_planes': (96,192,384,768),
85 |         'out_planes': (256,512,1024,2048),
86 |         'num_blocks': (3,4,20,3),
87 |         'dense_depth': (16,32,24,128)
88 |     }
89 |     return DPN(cfg)
90 | 
91 | 
92 | def test():
93 |     net = DPN92()
94 |     x = torch.randn(1,3,32,32)
95 |     y = net(x)
96 |     print(y)
97 | 
98 | # test()
99 | 
--------------------------------------------------------------------------------
/models/teacher/resnext.py:
--------------------------------------------------------------------------------
1 | '''ResNeXt in PyTorch.
2 | 
3 | See the paper "Aggregated Residual Transformations for Deep Neural Networks" for more details.
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Block(nn.Module): 11 | '''Grouped convolution block.''' 12 | expansion = 2 13 | 14 | def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1): 15 | super(Block, self).__init__() 16 | group_width = cardinality * bottleneck_width 17 | self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(group_width) 19 | self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False) 20 | self.bn2 = nn.BatchNorm2d(group_width) 21 | self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False) 22 | self.bn3 = nn.BatchNorm2d(self.expansion*group_width) 23 | 24 | self.shortcut = nn.Sequential() 25 | if stride != 1 or in_planes != self.expansion*group_width: 26 | self.shortcut = nn.Sequential( 27 | nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False), 28 | nn.BatchNorm2d(self.expansion*group_width) 29 | ) 30 | 31 | def forward(self, x): 32 | out = F.relu(self.bn1(self.conv1(x))) 33 | out = F.relu(self.bn2(self.conv2(out))) 34 | out = self.bn3(self.conv3(out)) 35 | out += self.shortcut(x) 36 | out = F.relu(out) 37 | return out 38 | 39 | 40 | class ResNeXt(nn.Module): 41 | def __init__(self, num_blocks, cardinality, bottleneck_width, num_classes=10, name=None): 42 | super(ResNeXt, self).__init__() 43 | self.model_name = name 44 | 45 | self.cardinality = cardinality 46 | self.bottleneck_width = bottleneck_width 47 | self.in_planes = 64 48 | 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(64) 51 | self.layer1 = self._make_layer(num_blocks[0], 1) 52 | self.layer2 = self._make_layer(num_blocks[1], 2) 53 | self.layer3 = self._make_layer(num_blocks[2], 2) 54 | # self.layer4 = self._make_layer(num_blocks[3], 2) 55 | 56 | self.linear = nn.Linear(cardinality*bottleneck_width*8, num_classes) 57 | 58 | self.output_size = cardinality * bottleneck_width * 8 59 | 60 | def _make_layer(self, num_blocks, stride): 61 | strides = [stride] + [1]*(num_blocks-1) 62 | layers = [] 63 | for stride in strides: 64 | layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride)) 65 | self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width 66 | # Increase bottleneck_width by 2 after each stage. 
67 | self.bottleneck_width *= 2 68 | return nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | b1 = self.layer1(out) 73 | b2 = self.layer2(b1) 74 | b3 = self.layer3(b2) 75 | # out = self.layer4(out) 76 | pool = F.avg_pool2d(b3, 8) 77 | out = pool.view(pool.size(0), -1) 78 | out = self.linear(out) 79 | return b1, b2, b3, pool, out 80 | 81 | 82 | def ResNeXt29_2x64d(num_classes=10): 83 | return ResNeXt(num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, num_classes=num_classes, name='ResNeXt29_2x64d') 84 | 85 | def ResNeXt29_4x64d(num_classes=10): 86 | return ResNeXt(num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, num_classes=num_classes, name='ResNeXt29_4x64d') 87 | 88 | def ResNeXt29_8x64d(num_classes=10): 89 | return ResNeXt(num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, num_classes=num_classes, name='ResNeXt29_8x64d') 90 | 91 | def ResNeXt29_32x4d(num_classes=10): 92 | return ResNeXt(num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, num_classes=num_classes, name='ResNeXt29_32x4d') 93 | 94 | def test_resnext(): 95 | net = ResNeXt29_2x64d() 96 | x = torch.randn(1,3,32,32) 97 | y = net(x) 98 | print(y.size()) 99 | 100 | # test_resnext() 101 | -------------------------------------------------------------------------------- /models/teacher/densenet.py: -------------------------------------------------------------------------------- 1 | '''DenseNet in PyTorch.''' 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, in_planes, growth_rate): 11 | super(Bottleneck, self).__init__() 12 | self.bn1 = nn.BatchNorm2d(in_planes) 13 | self.conv1 = nn.Conv2d(in_planes, 4*growth_rate, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(4*growth_rate) 15 | self.conv2 = nn.Conv2d(4*growth_rate, growth_rate, kernel_size=3, padding=1, bias=False) 16 | 17 | def forward(self, x): 18 | out = self.conv1(F.relu(self.bn1(x))) 19 | out = self.conv2(F.relu(self.bn2(out))) 20 | out = torch.cat([out,x], 1) 21 | return out 22 | 23 | 24 | class Transition(nn.Module): 25 | def __init__(self, in_planes, out_planes): 26 | super(Transition, self).__init__() 27 | self.bn = nn.BatchNorm2d(in_planes) 28 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False) 29 | 30 | def forward(self, x): 31 | out = self.conv(F.relu(self.bn(x))) 32 | out = F.avg_pool2d(out, 2) 33 | return out 34 | 35 | 36 | class DenseNet(nn.Module): 37 | def __init__(self, block, nblocks, growth_rate=12, reduction=0.5, num_classes=10, name=None): 38 | super(DenseNet, self).__init__() 39 | self.model_name = name 40 | self.growth_rate = growth_rate 41 | 42 | num_planes = 2*growth_rate 43 | self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False) 44 | 45 | self.dense1 = self._make_dense_layers(block, num_planes, nblocks[0]) 46 | num_planes += nblocks[0]*growth_rate 47 | out_planes = int(math.floor(num_planes*reduction)) 48 | self.trans1 = Transition(num_planes, out_planes) 49 | num_planes = out_planes 50 | 51 | self.dense2 = self._make_dense_layers(block, num_planes, nblocks[1]) 52 | num_planes += nblocks[1]*growth_rate 53 | out_planes = int(math.floor(num_planes*reduction)) 54 | self.trans2 = Transition(num_planes, out_planes) 55 | num_planes = out_planes 56 | 57 | self.dense3 = self._make_dense_layers(block, num_planes, nblocks[2]) 58 | num_planes += nblocks[2]*growth_rate 59 | out_planes = int(math.floor(num_planes*reduction)) 60 
| self.trans3 = Transition(num_planes, out_planes) 61 | num_planes = out_planes 62 | 63 | self.dense4 = self._make_dense_layers(block, num_planes, nblocks[3]) 64 | num_planes += nblocks[3]*growth_rate 65 | 66 | self.bn = nn.BatchNorm2d(num_planes) 67 | self.linear = nn.Linear(num_planes, num_classes) 68 | 69 | self.output_size = num_planes 70 | 71 | def _make_dense_layers(self, block, in_planes, nblock): 72 | layers = [] 73 | for i in range(nblock): 74 | layers.append(block(in_planes, self.growth_rate)) 75 | in_planes += self.growth_rate 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | out = self.conv1(x) 80 | b1 = self.trans1(self.dense1(out)) 81 | b2 = self.trans2(self.dense2(b1)) 82 | b3 = self.trans3(self.dense3(b2)) 83 | b4 = self.dense4(b3) 84 | pool = F.avg_pool2d(F.relu(self.bn(b4)), 4) 85 | out = pool.view(pool.size(0), -1) 86 | out = self.linear(out) 87 | return b1, b2, b3, pool, out 88 | 89 | def DenseNet121(num_classes=10): 90 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=32, num_classes=num_classes, name='DenseNet121') 91 | 92 | def DenseNet169(num_classes=10): 93 | return DenseNet(Bottleneck, [6,12,32,32], growth_rate=32, num_classes=num_classes, name='DenseNet169') 94 | 95 | def DenseNet201(num_classes=10): 96 | return DenseNet(Bottleneck, [6,12,48,32], growth_rate=32, num_classes=num_classes, name='DenseNet201') 97 | 98 | def DenseNet161(num_classes=10): 99 | return DenseNet(Bottleneck, [6,12,36,24], growth_rate=48, num_classes=num_classes, name='DenseNet161') 100 | 101 | def densenet_cifar(num_classes=10): 102 | return DenseNet(Bottleneck, [6,12,24,16], growth_rate=12, num_classes=num_classes) 103 | 104 | def test(): 105 | net = densenet_cifar() 106 | x = torch.randn(1,3,32,32) 107 | y = net(x) 108 | print(y) 109 | 110 | # test() 111 | -------------------------------------------------------------------------------- /models/teacher/preact_resnet.py: -------------------------------------------------------------------------------- 1 | '''Pre-activation ResNet in PyTorch. 2 | 3 | Reference: 4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 5 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | 
12 | class PreActBlock(nn.Module):
13 |     '''Pre-activation version of the BasicBlock.'''
14 |     expansion = 1
15 | 
16 |     def __init__(self, in_planes, planes, stride=1):
17 |         super(PreActBlock, self).__init__()
18 |         self.bn1 = nn.BatchNorm2d(in_planes)
19 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 |         self.bn2 = nn.BatchNorm2d(planes)
21 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
22 | 
23 |         if stride != 1 or in_planes != self.expansion*planes:
24 |             self.shortcut = nn.Sequential(
25 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
26 |             )
27 | 
28 |     def forward(self, x):
29 |         out = F.relu(self.bn1(x))
30 |         shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
31 |         out = self.conv1(out)
32 |         out = self.conv2(F.relu(self.bn2(out)))
33 |         out += shortcut
34 |         return out
35 | 
36 | 
37 | class PreActBottleneck(nn.Module):
38 |     '''Pre-activation version of the original Bottleneck module.'''
39 |     expansion = 4
40 | 
41 |     def __init__(self, in_planes, planes, stride=1):
42 |         super(PreActBottleneck, self).__init__()
43 |         self.bn1 = nn.BatchNorm2d(in_planes)
44 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
45 |         self.bn2 = nn.BatchNorm2d(planes)
46 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
47 |         self.bn3 = nn.BatchNorm2d(planes)
48 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
49 | 
50 |         if stride != 1 or in_planes != self.expansion*planes:
51 |             self.shortcut = nn.Sequential(
52 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
53 |             )
54 | 
55 |     def forward(self, x):
56 |         out = F.relu(self.bn1(x))
57 |         shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
58 |         out = self.conv1(out)
59 |         out = self.conv2(F.relu(self.bn2(out)))
60 |         out = self.conv3(F.relu(self.bn3(out)))
61 |         out += shortcut
62 |         return out
63 | 
64 | 
65 | class PreActResNet(nn.Module):
66 |     def __init__(self, block, num_blocks, num_classes=10):
67 |         super(PreActResNet, self).__init__()
68 |         self.in_planes = 64
69 | 
70 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
71 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
72 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
73 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
74 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
75 |         self.linear = nn.Linear(512*block.expansion, num_classes)
76 | 
77 |     def _make_layer(self, block, planes, num_blocks, stride):
78 |         strides = [stride] + [1]*(num_blocks-1)
79 |         layers = []
80 |         for stride in strides:
81 |             layers.append(block(self.in_planes, planes, stride))
82 |             self.in_planes = planes * block.expansion
83 |         return nn.Sequential(*layers)
84 | 
85 |     def forward(self, x):
86 |         out = self.conv1(x)
87 |         out = self.layer1(out)
88 |         out = self.layer2(out)
89 |         out = self.layer3(out)
90 |         out = self.layer4(out)
91 |         out = F.avg_pool2d(out, 4)
92 |         out = out.view(out.size(0), -1)
93 |         out = self.linear(out)
94 |         return out
95 | 
96 | 
97 | def PreActResNet18(num_classes=10):
98 |     return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes)
99 | 
100 | def PreActResNet34(num_classes=10):
101 |     return PreActResNet(PreActBlock, [3,4,6,3], num_classes=num_classes)
102 | 
103 | def PreActResNet50(num_classes=10):
104 |     return PreActResNet(PreActBottleneck, [3,4,6,3], num_classes=num_classes)
105 | 
106 | def PreActResNet101(num_classes=10):
107 |     return PreActResNet(PreActBottleneck, [3,4,23,3], num_classes=num_classes)
108 | 
109 | def PreActResNet152(num_classes=10):
110 |     return PreActResNet(PreActBottleneck, [3,8,36,3], num_classes=num_classes)
111 | 
112 | 
113 | def test():
114 |     net = PreActResNet18()
115 |     y = net(torch.randn(1,3,32,32))
116 |     print(y.size())
117 | 
118 | # test()
119 | 
--------------------------------------------------------------------------------
/models/teacher/resnet.py:
--------------------------------------------------------------------------------
1 | '''ResNet in PyTorch.
2 | For Pre-activation ResNet, see 'preact_resnet.py'.
3 | Reference:
4 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
5 |     Deep Residual Learning for Image Recognition. arXiv:1512.03385
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | 
11 | 
12 | class BasicBlock(nn.Module):
13 |     expansion = 1
14 | 
15 |     def __init__(self, in_planes, planes, stride=1):
16 |         super(BasicBlock, self).__init__()
17 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
18 |         self.bn1 = nn.BatchNorm2d(planes)
19 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
20 |         self.bn2 = nn.BatchNorm2d(planes)
21 | 
22 |         self.shortcut = nn.Sequential()
23 |         if stride != 1 or in_planes != self.expansion*planes:
24 |             self.shortcut = nn.Sequential(
25 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
26 |                 nn.BatchNorm2d(self.expansion*planes)
27 |             )
28 | 
29 |     def forward(self, x):
30 |         out = F.relu(self.bn1(self.conv1(x)))
31 |         out = self.bn2(self.conv2(out))
32 |         out += self.shortcut(x)
33 |         out = F.relu(out)
34 |         return out
35 | 
36 | 
37 | class Bottleneck(nn.Module):
38 |     expansion = 4
39 | 
40 |     def __init__(self, in_planes, planes, stride=1):
41 |         super(Bottleneck, self).__init__()
42 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
43 |         self.bn1 = nn.BatchNorm2d(planes)
44 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
45 |         self.bn2 = nn.BatchNorm2d(planes)
46 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
47 |         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
48 | 
49 |         self.shortcut = nn.Sequential()
50 |         if stride != 1 or in_planes != self.expansion*planes:
51 |             self.shortcut = nn.Sequential(
52 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
53 |                 nn.BatchNorm2d(self.expansion*planes)
54 |             )
55 | 
56 |     def forward(self, x):
57 |         out = F.relu(self.bn1(self.conv1(x)))
58 |         out = F.relu(self.bn2(self.conv2(out)))
59 |         out = self.bn3(self.conv3(out))
60 |         out += self.shortcut(x)
61 |         out = F.relu(out)
62 |         return out
63 | 
64 | 
65 | class ResNet(nn.Module):
66 |     def __init__(self, block, num_blocks, num_classes=10, name=None):
67 |         super(ResNet, self).__init__()
68 |         self.model_name = name
69 |         self.in_planes = 64
70 | 
71 |         self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
72 |         self.bn1 = nn.BatchNorm2d(64)
73 |         self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
74 |         self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
75 |         self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
76 |         self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
77 |         self.linear = nn.Linear(512*block.expansion, num_classes)
78 | 
79 |     def _make_layer(self, block, planes, num_blocks, stride):
80 |         strides = [stride] + [1]*(num_blocks-1)
81 |         layers = []
82 |         for stride in strides:
83 |             layers.append(block(self.in_planes, planes, stride))
84 |             self.in_planes = planes * block.expansion
85 |         return nn.Sequential(*layers)
86 | 
87 |     def forward(self, x):
88 |         out = F.relu(self.bn1(self.conv1(x)))
89 |         b1 = self.layer1(out)
90 |         b2 = self.layer2(b1)
91 |         b3 = self.layer3(b2)
92 |         b4 = self.layer4(b3)
93 |         pool = F.avg_pool2d(b4, 4)
94 |         out = pool.view(pool.size(0), -1)
95 |         out = self.linear(out)
96 |         return b1, b2, b3, pool, out
97 | 
98 | 
99 | def ResNet18(num_classes=10):
100 |     return ResNet(BasicBlock, [2,2,2,2], num_classes=num_classes, name='ResNet18')
101 | 
102 | def ResNet34(num_classes=10):
103 |     return ResNet(BasicBlock, [3,4,6,3], num_classes=num_classes, name='ResNet34')
104 | 
105 | def ResNet50(num_classes=10):
106 |     return ResNet(Bottleneck, [3,4,6,3], num_classes=num_classes, name='ResNet50')
107 | 
108 | def ResNet101(num_classes=10):
109 |     return ResNet(Bottleneck, [3,4,23,3], num_classes=num_classes, name='ResNet101')
110 | 
111 | def ResNet152(num_classes=10):
112 |     return ResNet(Bottleneck, [3,8,36,3], num_classes=num_classes, name='ResNet152')
113 | 
114 | 
115 | def test():
116 |     net = ResNet18()
117 |     *_, out = net(torch.randn(1,3,32,32))  # forward returns b1, b2, b3, pool, out
118 |     print(out.size())
119 | 
120 | # test()
--------------------------------------------------------------------------------
/dataset/tinyimagenet.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | from torchvision import models,utils,datasets,transforms
3 | import numpy as np
4 | import sys
5 | import os
6 | from PIL import Image
7 | 
8 | 
9 | class TinyImageNet(Dataset):
10 |     def __init__(self, root='../data/tiny-imagenet-200', train=True, transform=None):
11 |         self.Train = train
12 |         self.root_dir = root
13 |         self.transform = transform
14 |         self.train_dir = os.path.join(self.root_dir, "train")
15 |         self.val_dir = os.path.join(self.root_dir, "val")
16 | 
17 |         if (self.Train):
18 |             self._create_class_idx_dict_train()
19 |         else:
20 |             self._create_class_idx_dict_val()
21 | 
22 |         self._make_dataset(self.Train)
23 | 
24 |         words_file = os.path.join(self.root_dir, "words.txt")
25 |         wnids_file = os.path.join(self.root_dir, "wnids.txt")
26 | 
27 |         self.set_nids = set()
28 | 
29 |         with open(wnids_file, 'r') as fo:
30 |             data = fo.readlines()
31 |             for entry in data:
32 |                 self.set_nids.add(entry.strip("\n"))
33 | 
34 |         self.class_to_label = {}
35 |         with open(words_file, 'r') as fo:
36 |             data = fo.readlines()
37 |             for entry in data:
38 |                 words = entry.split("\t")
39 |                 if words[0] in self.set_nids:
40 |                     self.class_to_label[words[0]] = (words[1].strip("\n").split(","))[0]
41 | 
42 | 
43 |     def _create_class_idx_dict_train(self):
44 |         if sys.version_info >= (3,5):
45 |             classes = [d.name for d in os.scandir(self.train_dir) if d.is_dir()]
46 |         else:
47 |             classes = [d for d in os.listdir(self.train_dir) if os.path.isdir(os.path.join(self.train_dir, d))]
48 |         classes = sorted(classes)
49 |         num_images = 0
50 |         for root, dirs, files in os.walk(self.train_dir):
51 |             for f in files:
52 |                 if f.endswith(".JPEG"):
53 |                     num_images = num_images + 1
54 | 
55 |         self.len_dataset = num_images
56 | 
57 |         self.tgt_idx_to_class = {i:classes[i] for i in range(len(classes))}
58 |         self.class_to_tgt_idx = {classes[i]:i for i in range(len(classes))}
59 | 
60 |     def _create_class_idx_dict_val(self):
61 |         val_image_dir = os.path.join(self.val_dir, "images")
62 |         if sys.version_info >= (3,5):
63 |             images = [d.name for d in os.scandir(val_image_dir) if d.is_file()]
64 |         else:
65 |             images = [d for d in os.listdir(val_image_dir) if os.path.isfile(os.path.join(val_image_dir, d))]
66 |         val_annotations_file = os.path.join(self.val_dir, "val_annotations.txt")
67 |         self.val_img_to_class = {}
68 |         set_of_classes = set()
69 |         with open(val_annotations_file,'r') as fo:
70 |             entry = fo.readlines()
71 |             for data in entry:
72 |                 words = data.split("\t")
73 |                 self.val_img_to_class[words[0]] = words[1]
74 |                 set_of_classes.add(words[1])
75 | 
76 |         self.len_dataset = len(list(self.val_img_to_class.keys()))
77 |         classes = sorted(list(set_of_classes))
78 |         #self.idx_to_class = {i:self.val_img_to_class[images[i]] for i in range(len(images))}
79 |         self.class_to_tgt_idx = {classes[i]:i for i in range(len(classes))}
80 |         self.tgt_idx_to_class = {i:classes[i] for i in range(len(classes))}
81 | 
82 |     def _make_dataset(self, Train=True):
83 |         self.images = []
84 |         if Train:
85 |             img_root_dir = self.train_dir
86 |             list_of_dirs = [target for target in self.class_to_tgt_idx.keys()]
87 |         else:
88 |             img_root_dir = self.val_dir
89 |             list_of_dirs = ["images"]
90 | 
91 |         for tgt in list_of_dirs:
92 |             dirs = os.path.join(img_root_dir, tgt)
93 |             if not os.path.isdir(dirs):
94 |                 continue
95 | 
96 |             for root,_,files in sorted(os.walk(dirs)):
97 |                 for fname in sorted(files):
98 |                     if (fname.endswith(".JPEG")):
99 |                         path = os.path.join(root, fname)
100 |                         if Train:
101 |                             item = (path, self.class_to_tgt_idx[tgt])
102 |                         else:
103 |                             item = (path, self.class_to_tgt_idx[self.val_img_to_class[fname]])
104 |                         self.images.append(item)
105 | 
106 |     def return_label(self, idx):
107 |         return [self.class_to_label[self.tgt_idx_to_class[i.item()]] for i in idx]
108 | 
109 | 
110 |     def __len__(self):
111 |         return self.len_dataset
112 | 
113 | 
114 |     def __getitem__(self, idx):
115 |         img_path, tgt = self.images[idx]
116 |         with open(img_path, 'rb') as f:
117 |             sample = Image.open(f)
118 |             sample = sample.convert('RGB')
119 |         if self.transform is not None:
120 |             sample = self.transform(sample)
121 | 
122 |         return sample, tgt
--------------------------------------------------------------------------------
/models/teacher/resnet20.py:
--------------------------------------------------------------------------------
1 | '''
2 | Properly implemented ResNet-s for CIFAR10 as described in paper [1].
3 | The implementation and structure of this file is hugely influenced by [2]
4 | which is implemented for ImageNet and doesn't have option A for identity.
5 | Moreover, most of the implementations on the web are copy-paste from
6 | torchvision's resnet and have the wrong number of params.
7 | Proper ResNet-s for CIFAR10 (for fair comparison, etc.) have the following
8 | number of layers and parameters:
9 | name      | layers | params
10 | ResNet20  |    20  | 0.27M
11 | ResNet32  |    32  | 0.46M
12 | ResNet44  |    44  | 0.66M
13 | ResNet56  |    56  | 0.85M
14 | ResNet110 |   110  |  1.7M
15 | ResNet1202|  1202  | 19.4M
16 | which this implementation indeed has.
17 | Reference:
18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 |     Deep Residual Learning for Image Recognition. arXiv:1512.03385
20 | [2] https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
21 | If you use this implementation in your work, please don't forget to mention the
22 | author, Yerlan Idelbayev.
23 | ''' 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | import torch.nn.init as init 28 | 29 | from torch.autograd import Variable 30 | 31 | __all__ = ['ResNet', 'ResNet20', 'ResNet32', 'ResNet44', 'ResNet56', 'ResNet110', 'ResNet1202'] 32 | 33 | def _weights_init(m): 34 | classname = m.__class__.__name__ 35 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): 36 | init.kaiming_normal_(m.weight) # kaiming_normal is deprecated 37 | 38 | class LambdaLayer(nn.Module): 39 | def __init__(self, lambd): 40 | super(LambdaLayer, self).__init__() 41 | self.lambd = lambd 42 | 43 | def forward(self, x): 44 | return self.lambd(x) 45 | 46 | 47 | class BasicBlock(nn.Module): 48 | expansion = 1 49 | 50 | def __init__(self, in_planes, planes, stride=1, option='A'): 51 | super(BasicBlock, self).__init__() 52 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 53 | self.bn1 = nn.BatchNorm2d(planes) 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | 57 | self.shortcut = nn.Sequential() 58 | if stride != 1 or in_planes != planes: 59 | if option == 'A': 60 | """ 61 | For CIFAR10 ResNet paper uses option A. 62 | """ 63 | self.shortcut = LambdaLayer(lambda x: F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 64 | elif option == 'B': 65 | self.shortcut = nn.Sequential( 66 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 67 | nn.BatchNorm2d(self.expansion * planes) 68 | ) 69 | 70 | def forward(self, x): 71 | out = F.relu(self.bn1(self.conv1(x))) 72 | out = self.bn2(self.conv2(out)) 73 | out += self.shortcut(x) 74 | out = F.relu(out) 75 | return out 76 | 77 | 78 | class ResNet(nn.Module): 79 | def __init__(self, block, num_blocks, num_classes=10, name=None): 80 | super(ResNet, self).__init__() 81 | self.in_planes = 16 82 | self.model_name = name 83 | self.output_size = 64 84 | 85 | self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 86 | self.bn1 = nn.BatchNorm2d(16) 87 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 88 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 89 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 90 | self.linear = nn.Linear(64, num_classes) 91 | 92 | self.apply(_weights_init) 93 | 94 | def _make_layer(self, block, planes, num_blocks, stride): 95 | strides = [stride] + [1]*(num_blocks-1) 96 | layers = [] 97 | for stride in strides: 98 | layers.append(block(self.in_planes, planes, stride)) 99 | self.in_planes = planes * block.expansion 100 | 101 | return nn.Sequential(*layers) 102 | 103 | def forward(self, x): 104 | out = F.relu(self.bn1(self.conv1(x))) 105 | b1 = self.layer1(out) 106 | b2 = self.layer2(b1) 107 | b3 = self.layer3(b2) 108 | pool = F.avg_pool2d(b3, b3.size()[3]) 109 | out = pool.view(pool.size(0), -1) 110 | out = self.linear(out) 111 | return b1, b2, b3, pool, out 112 | 113 | 114 | def ResNet20(num_classes=10): 115 | return ResNet(BasicBlock, [3, 3, 3], num_classes, name='ResNet20') 116 | 117 | 118 | def ResNet32(num_classes=10): 119 | return ResNet(BasicBlock, [5, 5, 5], num_classes, name='ResNet32') 120 | 121 | 122 | def ResNet44(num_classes=10): 123 | return ResNet(BasicBlock, [7, 7, 7], num_classes, name='ResNet44') 124 | 125 | 126 | def ResNet56(num_classes=10): 127 | return ResNet(BasicBlock, [9, 9, 9], num_classes, name='ResNet56') 128 | 129 |
130 | def ResNet110(num_classes=10): 131 | return ResNet(BasicBlock, [18, 18, 18], num_classes, name='ResNet110') 132 | 133 | 134 | def ResNet1202(num_classes=10): 135 | return ResNet(BasicBlock, [200, 200, 200], num_classes, name='ResNet1202') 136 | 137 | 138 | def test(net): 139 | import numpy as np 140 | total_params = 0 141 | 142 | for x in filter(lambda p: p.requires_grad, net.parameters()): 143 | total_params += np.prod(x.data.numpy().shape) 144 | print("Total number of params", total_params) 145 | print("Total layers", len(list(filter(lambda p: p.requires_grad and len(p.data.size())>1, net.parameters())))) 146 | 147 | ''' 148 | if __name__ == "__main__": 149 | for net_name in __all__: 150 | if net_name.startswith('ResNet'): 151 | print(net_name) 152 | test(globals()[net_name]()) 153 | print() 154 | '''
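A quick way to check the parameter table in the docstring above (a sketch; test() prints the same count in a different form):

net = ResNet20()
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(n_params)  # roughly 0.27M for ResNet20, matching the table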
-------------------------------------------------------------------------------- /train_resnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.parallel 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim 10 | import torch.utils.data 11 | import torchvision.transforms as transforms 12 | import torchvision.datasets as datasets 13 | from models import resnet20 14 | # # Teacher models 15 | # from models.teacher import * 16 | # # Student models 17 | # from models.student import * 18 | 19 | model_names = sorted(name for name in resnet20.__dict__ 20 | if not name.startswith("__") 21 | and name.startswith("ResNet") 22 | and callable(resnet20.__dict__[name])) 23 | 24 | parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch') 25 | parser.add_argument('--arch', '-a', metavar='ARCH', default='ResNet32', 26 | choices=model_names, 27 | help='model architecture: ' + ' | '.join(model_names) + 28 | ' (default: ResNet32)') 29 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 30 | help='number of data loading workers (default: 4)') 31 | parser.add_argument('--epochs', default=200, type=int, metavar='N', 32 | help='number of total epochs to run') 33 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 34 | help='manual epoch number (useful on restarts)') 35 | parser.add_argument('-b', '--batch-size', default=128, type=int, 36 | metavar='N', help='mini-batch size (default: 128)') 37 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 38 | metavar='LR', help='initial learning rate') 39 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 40 | help='momentum') 41 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 42 | metavar='W', help='weight decay (default: 1e-4)') 43 | parser.add_argument('--print-freq', '-p', default=50, type=int, 44 | metavar='N', help='print frequency (default: 50)') 45 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 46 | help='path to latest checkpoint (default: none)') 47 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 48 | help='evaluate model on validation set') 49 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 50 | help='use pre-trained model') 51 | parser.add_argument('--half', dest='half', action='store_true', 52 | help='use half-precision(16-bit) ') 53 | parser.add_argument('--save-dir', dest='save_dir', 54 | help='The directory used to save the trained models', 55 | default='save_temp', type=str) 56 | parser.add_argument('--save-every', dest='save_every', 57 | help='Saves checkpoints at every specified number of epochs', 58 | type=int, default=20) 59 | best_prec1 = 0 60 | 61 | 62 | def main(config): 63 | global args, best_prec1 64 | args = parser.parse_args(config) 65 | 66 | 67 | # Check the save_dir exists or not 68 | if not os.path.exists(args.save_dir): 69 | os.makedirs(args.save_dir) 70 | 71 | model = torch.nn.DataParallel(resnet20.__dict__[args.arch]()) 72 | model.cuda() 73 | 74 | # optionally resume from a checkpoint 75 | if args.resume: 76 | if os.path.isfile(args.resume): 77 | print("=> loading checkpoint '{}'".format(args.resume)) 78 | checkpoint = torch.load(args.resume) 79 | args.start_epoch = checkpoint['epoch'] 80 | best_prec1 = checkpoint['best_prec1'] 81 | model.load_state_dict(checkpoint['state_dict']) 82 | print("=> loaded checkpoint '{}' (epoch {})" 83 | .format(args.resume, checkpoint['epoch'])) 84 | else: 85 | print("=> no checkpoint found at '{}'".format(args.resume)) 86 | 87 | cudnn.benchmark = True 88 | 89 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 90 | 91 | train_loader = torch.utils.data.DataLoader( 92 | datasets.CIFAR10(root='../data', train=True, transform=transforms.Compose([ 93 | transforms.RandomHorizontalFlip(), 94 | transforms.RandomCrop(32, 4), 95 | transforms.ToTensor(), 96 | normalize, 97 | ]), download=True), 98 | batch_size=args.batch_size, shuffle=True, 99 | num_workers=args.workers, pin_memory=True) 100 | 101 | val_loader = torch.utils.data.DataLoader( 102 | datasets.CIFAR10(root='../data', train=False, transform=transforms.Compose([ 103 | transforms.ToTensor(), 104 | normalize, 105 | ])), 106 | batch_size=128, shuffle=False, 107 | num_workers=args.workers, pin_memory=True) 108 | 109 | # define loss function (criterion) and optimizer 110 | criterion = nn.CrossEntropyLoss().cuda() 111 | 112 | if args.half: 113 | model.half() 114 | criterion.half() 115 | 116 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 117 | momentum=args.momentum, 118 | weight_decay=args.weight_decay) 119 | 120 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], last_epoch=args.start_epoch - 1) 121 | 122 | if args.arch in ['ResNet1202', 'ResNet110']: 123 | # for ResNet110/1202 the original paper uses lr=0.01 for the first 400 minibatches as warm-up, 124 | # then switches back; in this implementation that corresponds to the first epoch.
125 | for param_group in optimizer.param_groups: 126 | param_group['lr'] = args.lr * 0.1 127 | 128 | if args.evaluate: 129 | validate(val_loader, model, criterion) 130 | return 131 | 132 | for epoch in range(args.start_epoch, args.epochs): 133 | 134 | # train for one epoch 135 | print('current lr {:.5e}'.format(optimizer.param_groups[0]['lr'])) 136 | train(train_loader, model, criterion, optimizer, epoch) 137 | lr_scheduler.step() 138 | 139 | # evaluate on validation set 140 | prec1 = validate(val_loader, model, criterion) 141 | 142 | # remember best prec@1 and save checkpoint 143 | is_best = prec1 > best_prec1 144 | best_prec1 = max(prec1, best_prec1) 145 | 146 | if epoch > 0 and epoch % args.save_every == 0: 147 | save_checkpoint({ 148 | 'epoch': epoch + 1, 149 | 'state_dict': model.state_dict(), 150 | 'best_prec1': best_prec1, 151 | }, is_best, filename=os.path.join(args.save_dir, 'checkpoint.th')) 152 | 153 | save_checkpoint({ 154 | 'state_dict': model.state_dict(), 155 | 'best_prec1': best_prec1, 156 | }, is_best, filename=os.path.join(args.save_dir, 'model.th')) 157 | 158 | 159 | def train(train_loader, model, criterion, optimizer, epoch): 160 | """ 161 | Run one train epoch 162 | """ 163 | batch_time = AverageMeter() 164 | data_time = AverageMeter() 165 | losses = AverageMeter() 166 | top1 = AverageMeter() 167 | 168 | # switch to train mode 169 | model.train() 170 | 171 | end = time.time() 172 | for i, (input, target) in enumerate(train_loader): 173 | 174 | # measure data loading time 175 | data_time.update(time.time() - end) 176 | 177 | target = target.cuda() 178 | input_var = torch.autograd.Variable(input).cuda() 179 | target_var = torch.autograd.Variable(target) 180 | if args.half: 181 | input_var = input_var.half() 182 | 183 | # compute output 184 | output = model(input_var) 185 | loss = criterion(output, target_var) 186 | 187 | # compute gradient and do SGD step 188 | optimizer.zero_grad() 189 | loss.backward() 190 | optimizer.step() 191 | 192 | output = output.float() 193 | loss = loss.float() 194 | # measure accuracy and record loss 195 | prec1 = accuracy(output.data, target)[0] 196 | losses.update(loss.item(), input.size(0)) 197 | top1.update(prec1, input.size(0)) 198 | 199 | # measure elapsed time 200 | batch_time.update(time.time() - end) 201 | end = time.time() 202 | 203 | if i % args.print_freq == 0: 204 | print('Epoch: [{0}][{1}/{2}]\t' 205 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 206 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 207 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 208 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 209 | epoch, i, len(train_loader), batch_time=batch_time, 210 | data_time=data_time, loss=losses, top1=top1)) 211 | 212 | 213 | def validate(val_loader, model, criterion): 214 | """ 215 | Run evaluation 216 | """ 217 | batch_time = AverageMeter() 218 | losses = AverageMeter() 219 | top1 = AverageMeter() 220 | 221 | # switch to evaluate mode 222 | model.eval() 223 | 224 | end = time.time() 225 | with torch.no_grad(): 226 | for i, (input, target) in enumerate(val_loader): 227 | target = target.cuda() 228 | input_var = torch.autograd.Variable(input).cuda() 229 | target_var = torch.autograd.Variable(target) 230 | 231 | if args.half: 232 | input_var = input_var.half() 233 | 234 | # compute output 235 | output = model(input_var) 236 | loss = criterion(output, target_var) 237 | 238 | output = output.float() 239 | loss = loss.float() 240 | 241 | # measure accuracy and record loss 242 | prec1 = accuracy(output.data, target)[0] 243 | 
losses.update(loss.item(), input.size(0)) 244 | top1.update(prec1, input.size(0)) 245 | 246 | # measure elapsed time 247 | batch_time.update(time.time() - end) 248 | end = time.time() 249 | 250 | if i % args.print_freq == 0: 251 | print('Test: [{0}/{1}]\t' 252 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 253 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 254 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 255 | i, len(val_loader), batch_time=batch_time, loss=losses, 256 | top1=top1)) 257 | 258 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 259 | 260 | return top1.avg 261 | 262 | def save_checkpoint(state, is_best, filename='checkpoint/resnet20.pth.tar'): 263 | """ 264 | Save the training model 265 | """ 266 | torch.save(state, filename) 267 | 268 | class AverageMeter(object): 269 | """Computes and stores the average and current value""" 270 | def __init__(self): 271 | self.reset() 272 | 273 | def reset(self): 274 | self.val = 0 275 | self.avg = 0 276 | self.sum = 0 277 | self.count = 0 278 | 279 | def update(self, val, n=1): 280 | self.val = val 281 | self.sum += val * n 282 | self.count += n 283 | self.avg = self.sum / self.count 284 | 285 | 286 | def accuracy(output, target, topk=(1,)): 287 | """Computes the precision@k for the specified values of k""" 288 | maxk = max(topk) 289 | batch_size = target.size(0) 290 | 291 | _, pred = output.topk(maxk, 1, True, True) 292 | pred = pred.t() 293 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 294 | 295 | res = [] 296 | for k in topk: 297 | correct_k = correct[:k].reshape(-1).float().sum(0) # reshape: the transposed slice is non-contiguous 298 | res.append(correct_k.mul_(100.0 / batch_size)) 299 | return res 300 | 301 | 302 | if __name__ == '__main__': 303 | config = ['--arch', 'ResNet20', '--epochs', '200'] 304 | main(config)
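As a worked example of the top-k bookkeeping inside accuracy() (the logits below are made up): pred holds the top-k class indices per sample, and correct[:k] marks whether the target appears among them.

import torch

output = torch.tensor([[0.1, 0.7, 0.2],   # sample 0: argmax is class 1, which is its target
                       [0.8, 0.1, 0.1]])  # sample 1: target class 2 is outside its top-2
target = torch.tensor([1, 2])
print(accuracy(output, target, topk=(1, 2)))  # [tensor(50.), tensor(50.)]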
-------------------------------------------------------------------------------- /FitNet_distill_hook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-06-21T02:18:39.651532Z", 9 | "start_time": "2019-06-21T02:18:39.184663Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "WARNING:root:Setting up a new session...\n" 18 | ] 19 | } 20 | ], 21 | "source": [ 22 | "from __future__ import print_function\n", 23 | "import os\n", 24 | "import time\n", 25 | "import logging\n", 26 | "import argparse\n", 27 | "from visdom import Visdom\n", 28 | "import numpy as np\n", 29 | "import torch\n", 30 | "import torch.nn as nn\n", 31 | "import torch.nn.functional as F\n", 32 | "import torch.optim as optim\n", 33 | "from torch.autograd import Variable\n", 34 | "from torch.utils.data import DataLoader\n", 35 | "from torchvision import datasets, transforms\n", 36 | "\n", 37 | "# Teacher models\n", 38 | "from models.teacher import *\n", 39 | "\n", 40 | "# Student models\n", 41 | "from models.student import *\n", 42 | "\n", 43 | "\n", 44 | "start_time = time.time()\n", 45 | "# os.makedirs('./checkpoint', exist_ok=True)\n", 46 | "\n", 47 | "# Training settings\n", 48 | "parser = argparse.ArgumentParser(description='PyTorch Distill Example')\n", 49 | "parser.add_argument('--teacher', type=str, default='VGG19', help='teacher net: AlexNet, VGG11/13/16/19, GoogLeNet')\n", 50 | "parser.add_argument('--student', type=str, default='FitNet11', help='student net: LeNet5, ')\n", 51 | "parser.add_argument('--T', type=float, default=20.0, metavar='Temperature', help='Temperature for distillation')\n", 52 | "parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 53 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input test batch size for training')\n", 54 | "parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')\n", 55 | "parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')\n", 56 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')\n", 57 | "parser.add_argument('--cuda', action='store_true', default=torch.cuda.is_available(), help='use CUDA training')\n", 58 | "parser.add_argument('--log_interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')\n", 59 | "\n", 60 | "config = ['--epochs', '50', '--teacher', 'Hint7', '--student', 'FitNet11', '--T', '10', '--cuda']\n", 61 | "args = parser.parse_args(config)\n", 62 | "\n", 63 | "device = 'cuda:0' if args.cuda and torch.cuda.is_available() else 'cpu'\n", 64 | "\n", 65 | "# logging\n", 66 | "logfile = './checkpoint/distill_' + args.teacher + '_' + args.student + '.log'\n", 67 | "if os.path.exists(logfile):\n", 68 | " os.remove(logfile)\n", 69 | "\n", 70 | "def log_out(info):\n", 71 | " f = open(logfile, mode='a')\n", 72 | " f.write(info)\n", 73 | " f.write('\\n')\n", 74 | " f.close()\n", 75 | " print(info)\n", 76 | " \n", 77 | "# visualizer\n", 78 | "vis = Visdom(env='distill')\n", 79 | "loss_win = vis.line(\n", 80 | " X=np.array([0]),\n", 81 | " Y=np.array([0]),\n", 82 | " opts=dict(\n", 83 | " title='train loss',\n", 84 | " xtickmin=0,\n", 85 | "# xtickmax=1,\n", 86 | " xtickstep=5,\n", 87 | " ytickmin=0,\n", 88 | "# ytickmax=1,\n", 89 | " ytickstep=0.5,\n", 90 | " markers=True,\n", 91 | " markersymbol='dot',\n", 92 | " markersize=5,\n", 93 | " ),\n", 94 | " name=\"loss\"\n", 95 | ")\n", 96 | " \n", 97 | "acc_win = vis.line(\n", 98 | " X=np.column_stack((0, 0)),\n", 99 | " Y=np.column_stack((0, 0)),\n", 100 | " opts=dict(\n", 101 | " title='ACC',\n", 102 | " xtickmin=0,\n", 103 | " xtickstep=5,\n", 104 | " ytickmin=0,\n", 105 | " ytickmax=100,\n", 106 | " markers=True,\n", 107 | " markersymbol='dot',\n", 108 | " markersize=5,\n", 109 | " legend=['train_acc', 'test_acc']\n", 110 | " ),\n", 111 | " name=\"acc\"\n", 112 | ")\n", 113 | "\n", 114 | "# weights init\n", 115 | "def weights_init_normal(m):\n", 116 | " classname = m.__class__.__name__\n", 117 | " if classname.find('Conv') != -1:\n", 118 | " nn.init.normal_(m.weight.data, 0.0, 0.02)\n", 119 | " elif classname.find(\"BatchNorm2d\") != -1:\n", 120 | " nn.init.normal_(m.weight.data, 1.0, 0.02)\n", 121 | " nn.init.constant_(m.bias.data, 0.0)\n", 122 | " elif classname.find('Linear') != -1: # was 'linear', which never matches nn.Linear\n", 123 | " nn.init.normal_(m.weight.data, 0.0, 0.02) # zero-mean init for Linear weights\n", 124 | " nn.init.constant_(m.bias.data, 0.0)\n" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# class ConvRegressor(nn.Module):\n", 134 | "# def __init__(self, teacher, hint_layer, student, guided_layer):\n", 135 | "# self.hint_layer = teacher.\n", 136 | " " 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 4, 142 | "metadata": { 143 | "ExecuteTime": { 144 | "end_time": "2019-06-21T08:25:23.258466Z", 145 | "start_time":
"2019-06-21T08:25:22.836761Z" 146 | }, 147 | "code_folding": [] 148 | }, 149 | "outputs": [ 150 | { 151 | "ename": "FileNotFoundError", 152 | "evalue": "[Errno 2] No such file or directory: './checkpoint/Hint7_cifar10.pth'", 153 | "output_type": "error", 154 | "traceback": [ 155 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 156 | "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", 157 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mteacher_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mteacher\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mteacher_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_state_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'./checkpoint/'\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mteacher\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'_cifar10.pth'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mst_model\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstudent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mst_model\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweights_init_normal\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# init student\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 158 | "\u001b[0;32m~/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/serialization.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0municode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 381\u001b[0m \u001b[0mnew_fd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 382\u001b[0;31m \u001b[0mf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 383\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mversion_info\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m 
\u001b[0;34m==\u001b[0m \u001b[0;36m3\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpathlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0mnew_fd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 159 | "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: './checkpoint/Hint7_cifar10.pth'" 160 | ] 161 | } 162 | ], 163 | "source": [ 164 | "teacher_model = eval(args.teacher)().to(device)\n", 165 | "teacher_model.load_state_dict(torch.load('./checkpoint/' + args.teacher + '_cifar10.pth'))\n", 166 | "st_model = eval(args.student)().to(device)\n", 167 | "st_model.apply(weights_init_normal) # init student\n", 168 | "\n", 169 | "st_features = None\n", 170 | "te_features = None\n", 171 | "\n", 172 | "def st_hook(module, input, output):\n", 173 | " '''Copy this layer's output into st_features.'''\n", 174 | " global st_features\n", 175 | " st_features = output.data\n", 176 | " \n", 177 | "def te_hook(module, input, output):\n", 178 | " '''Copy this layer's output into te_features.'''\n", 179 | " global te_features\n", 180 | " te_features = output.data\n", 181 | "\n", 182 | "class Regressor(nn.Module):\n", 183 | " def __init__(self):\n", 184 | " super(Regressor,self).__init__()\n", 185 | " # torch.Size([128, 2048, 16, 16]) -> torch.Size([128, 80, 4, 4]) \n", 186 | " self.features = nn.Sequential(\n", 187 | " nn.Conv2d(2048, 512, 3, 1, 1), # ch: 2048 -> 512 \n", 188 | " nn.BatchNorm2d(512),\n", 189 | " nn.ReLU(inplace=True),\n", 190 | " nn.MaxPool2d(kernel_size=2, stride=2), # size: 16 -> 8\n", 191 | " \n", 192 | " nn.Conv2d(512, 128, 3, 1, 1), # ch: 512 -> 128\n", 193 | " nn.BatchNorm2d(128),\n", 194 | " nn.ReLU(inplace=True),\n", 195 | " nn.MaxPool2d(kernel_size=2, stride=2), # size: 8 -> 4\n", 196 | " \n", 197 | " nn.Conv2d(128, 80, 3, 1, 1), # ch: 128 -> 80\n", 198 | " nn.BatchNorm2d(80), # was BatchNorm2d(128): must match the 80 output channels\n", 199 | " nn.ReLU(inplace=True)\n", 200 | " )\n", 201 | " \n", 202 | " def forward(self, x):\n", 203 | " return self.features(x)\n", 204 | "\n", 205 | "regressor = Regressor().to(device)\n", 206 | "st_model.features[15].register_forward_hook(st_hook)\n", 207 | "teacher_model.features[11].register_forward_hook(te_hook)\n", 208 | "\n", 209 | "# data\n", 210 | "train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n", 211 | "test_transform = transforms.Compose([transforms.ToTensor()])\n", 212 | "train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=train_transform)\n", 213 | "test_set = datasets.CIFAR10(root='../data', train=False, download=False, transform=test_transform)\n", 214 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 215 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 216 | "\n", 217 | "optimizer = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum)\n", 218 | "optimizer_r = optim.SGD(regressor.parameters(), lr=args.lr, momentum=args.momentum)\n", 219 | "\n", 220 | "def distillation(y, labels, teacher_scores, T, alpha):\n", 221 | "    return nn.KLDivLoss()(F.log_softmax(y/T, dim=1), F.softmax(teacher_scores/T, dim=1)) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)\n",
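To see what the temperature T in distillation() does, here is a tiny illustration with hypothetical logits: dividing by a larger T flattens the teacher distribution the student is asked to match, exposing the relative ranking of the wrong classes.

import torch
import torch.nn.functional as F

logits = torch.tensor([[10.0, 5.0, 1.0]])
print(F.softmax(logits / 1.0, dim=1))   # ~[0.993, 0.007, 0.000] - nearly one-hot
print(F.softmax(logits / 10.0, dim=1))  # ~[0.497, 0.301, 0.202] - soft targets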
222 | "\n", 223 | "\n", 224 | "# guided train\n", 225 | "def guided_train(model, loss_fn_guided=nn.MSELoss()):\n", 226 | " model.train()\n", 227 | " teacher_model.eval()\n", 228 | " guided_loss = None\n", 229 | " for epoch in range(10):\n", 230 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 231 | " data, target = data.to(device), target.to(device)\n", 232 | " optimizer.zero_grad()\n", 233 | " optimizer_r.zero_grad()\n", 234 | " model(data) # st_hook captures the guided-layer features\n", 235 | " teacher_model(data) # te_hook captures the hint features (already detached via .data)\n", 236 | " te_output = regressor(te_features)\n", 237 | "# print(st_features.size(), te_hint.size())\n", 238 | " \n", 239 | " guided_loss = loss_fn_guided(te_output, st_features)\n", 240 | " guided_loss.backward()\n", 241 | " optimizer_r.step()\n", 242 | " optimizer.step()\n", 243 | " print('guided_epoch:[{}]\\tLoss:{:.4f}'.format(epoch, guided_loss.item()))\n", 244 | "\n", 245 | "\n", 246 | "def train(epoch, model, loss_fn):\n", 247 | " model.train()\n", 248 | " teacher_model.eval()\n", 249 | " loss = None\n", 250 | " for batch_idx, (data, target) in enumerate(train_loader):\n", 251 | " data, target = data.to(device), target.to(device)\n", 252 | " optimizer.zero_grad()\n", 253 | " output = model(data)\n", 254 | " teacher_output = teacher_model(data)\n", 255 | " teacher_output = teacher_output.detach()\n", 256 | "# print(st_features.size())\n", 257 | " # teacher_output = Variable(teacher_output.data, requires_grad=False) #alternative approach to load teacher_output\n", 258 | " loss = loss_fn(output, target, teacher_output, T=args.T, alpha=0.6)\n", 259 | " loss.backward()\n", 260 | " optimizer.step()\n", 261 | " if batch_idx % args.log_interval == 0:\n", 262 | " log_out('Train Epoch: {} [{}/{} ({:.4f}%)]\\tLoss: {:.6f}'.format(\n", 263 | " epoch, batch_idx * len(data), len(train_loader.dataset),\n", 264 | " 100. * batch_idx / len(train_loader), loss.item()))\n", 265 | " return loss.item()\n", 266 | "\n", 267 | "def train_evaluate(model):\n", 268 | " model.eval()\n", 269 | " train_loss = 0\n", 270 | " correct = 0\n", 271 | " with torch.no_grad():\n", 272 | " for data, target in train_loader:\n", 273 | " data, target = data.to(device), target.to(device)\n", 274 | " output = model(data)\n", 275 | " train_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss\n", 276 | " pred = output.data.max(1, keepdim=True)[1]\n", 277 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 278 | "\n", 279 | " log_out('\\nTrain set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\\n'.format(\n", 280 | " train_loss / len(train_loader.dataset), correct, len(train_loader.dataset), \n", 281 | " 100. * correct / len(train_loader.dataset)))\n", 282 | " return 100. * correct / len(train_loader.dataset)\n", 283 | "\n", 284 | "def test(model):\n", 285 | " model.eval()\n", 286 | " test_loss = 0\n", 287 | " correct = 0\n", 288 | " with torch.no_grad():\n", 289 | " for data, target in test_loader:\n", 290 | " data, target = data.to(device), target.to(device)\n", 291 | " output = model(data)\n", 292 | " test_loss += F.cross_entropy(output, target, reduction='sum').item() # sum up batch loss\n", 293 | " pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n", 294 | " correct += pred.eq(target.data.view_as(pred)).cpu().sum()\n", 295 | "\n", 296 | " test_loss /= len(test_loader.dataset)\n", 297 | " log_out('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\\n'.format(\n", 298 | " test_loss, correct, len(test_loader.dataset),\n", 299 | " 100. * correct / len(test_loader.dataset)))\n",
300 | " return 100. * correct / len(test_loader.dataset)\n", 301 | "\n", 302 | "print('StudentNet:\\n')\n", 303 | "print(st_model)\n", 304 | "guided_train(st_model)\n", 305 | "for epoch in range(1, args.epochs + 1):\n", 306 | " train_loss = train(epoch, st_model, loss_fn=distillation)\n", 307 | " # visualize loss\n", 308 | " vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update=\"append\")\n", 309 | " train_acc = train_evaluate(st_model)\n", 310 | " test_acc = test(st_model)\n", 311 | " vis.line(np.column_stack((train_acc, test_acc)), np.column_stack((epoch, epoch)), acc_win, update=\"append\")\n", 312 | "\n", 313 | "\n", 314 | "torch.save(st_model.state_dict(), './checkpoint/' + args.teacher + '_distill_' + args.student + '.pth')\n", 315 | "# the_model = Net()\n", 316 | "# the_model.load_state_dict(torch.load('student.pth.tar'))\n", 317 | "\n", 318 | "# test(the_model)\n", 319 | "# for data, target in test_loader:\n", 320 | "# data, target = Variable(data, volatile=True), Variable(target)\n", 321 | "# teacher_out = the_model(data)\n", 322 | "# print(teacher_out)\n", 323 | "log_out(\"--- {:.3f} seconds ---\".format(time.time() - start_time))\n" 324 | ] 325 | } 326 | ], 327 | "metadata": { 328 | "kernelspec": { 329 | "display_name": "Python [conda env:py37] *", 330 | "language": "python", 331 | "name": "conda-env-py37-py" 332 | }, 333 | "language_info": { 334 | "codemirror_mode": { 335 | "name": "ipython", 336 | "version": 3 337 | }, 338 | "file_extension": ".py", 339 | "mimetype": "text/x-python", 340 | "name": "python", 341 | "nbconvert_exporter": "python", 342 | "pygments_lexer": "ipython3", 343 | "version": "3.7.1" 344 | } 345 | }, 346 | "nbformat": 4, 347 | "nbformat_minor": 2 348 | } 349 | -------------------------------------------------------------------------------- /RKD.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "start_time": "2019-08-19T01:43:26.329Z" 9 | } 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stderr", 14 | "output_type": "stream", 15 | "text": [ 16 | "/home/data/yaliu/jupyterbooks/multi-KD/models/teacher/resnet20.py:36: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n", 17 | " init.kaiming_normal(m.weight)\n", 18 | "WARNING:root:Setting up a new session...\n" 19 | ] 20 | }, 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Files already downloaded and verified\n", 26 | "StudentNet:\n", 27 | "\n", 28 | "ResNet(\n", 29 | " (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 30 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 31 | " (layer1): Sequential(\n", 32 | " (0): BasicBlock(\n", 33 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 34 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 35 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 36 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 37 | " (shortcut): Sequential()\n", 38 | " )\n", 39 | " (1): BasicBlock(\n", 40 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 41 | " (bn1): BatchNorm2d(16, eps=1e-05,
momentum=0.1, affine=True, track_running_stats=True)\n", 42 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 43 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 44 | " (shortcut): Sequential()\n", 45 | " )\n", 46 | " (2): BasicBlock(\n", 47 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 48 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 49 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 50 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 51 | " (shortcut): Sequential()\n", 52 | " )\n", 53 | " )\n", 54 | " (layer2): Sequential(\n", 55 | " (0): BasicBlock(\n", 56 | " (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 57 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 58 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 59 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 60 | " (shortcut): LambdaLayer()\n", 61 | " )\n", 62 | " (1): BasicBlock(\n", 63 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 64 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 65 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 66 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 67 | " (shortcut): Sequential()\n", 68 | " )\n", 69 | " (2): BasicBlock(\n", 70 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 71 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 72 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " (shortcut): Sequential()\n", 75 | " )\n", 76 | " )\n", 77 | " (layer3): Sequential(\n", 78 | " (0): BasicBlock(\n", 79 | " (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 82 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (shortcut): LambdaLayer()\n", 84 | " )\n", 85 | " (1): BasicBlock(\n", 86 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 87 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 88 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 89 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (shortcut): Sequential()\n", 91 | " )\n", 92 | " (2): BasicBlock(\n", 93 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 94 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 96 | " (bn2): 
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (shortcut): Sequential()\n", 98 | " )\n", 99 | " )\n", 100 | " (linear): Linear(in_features=64, out_features=100, bias=True)\n", 101 | ")\n", 102 | "\n", 103 | "===> epoch: 1/200\n", 104 | "current lr 1.00000e-01\n", 105 | "Training:\n", 106 | "[0/391]\tTime 0.062 (0.062)\tData 0.020 (0.020)\tLoss 6.4279 (6.4279)\tPrec@1 0.781 (0.781)\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/_reduction.py:15: UserWarning: reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.\n", 114 | " warnings.warn(\"reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.\")\n" 115 | ] 116 | }, 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "[40/391]\tTime 0.055 (0.056)\tData 0.018 (0.019)\tLoss 4.5492 (4.9636)\tPrec@1 3.906 (3.030)\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "from __future__ import print_function\n", 127 | "import os\n", 128 | "import time\n", 129 | "import logging\n", 130 | "import argparse\n", 131 | "import numpy as np\n", 132 | "from visdom import Visdom\n", 133 | "from PIL import Image\n", 134 | "import torch\n", 135 | "import torch.nn as nn\n", 136 | "import torch.nn.functional as F\n", 137 | "import torch.optim as optim\n", 138 | "from torch.autograd import Variable\n", 139 | "from torch.utils.data import DataLoader\n", 140 | "from torchvision import datasets, transforms\n", 141 | "from utils import *\n", 142 | "from metric.loss import FitNet, AttentionTransfer, RKdAngle, RkdDistance\n", 143 | "\n", 144 | "# Teacher models:\n", 145 | "# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34, \n", 146 | "# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, \n", 147 | "# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, \n", 148 | "# PreActResNet50, PreActResNet101, PreActResNet152, \n", 149 | "# DenseNet121, DenseNet161, DenseNet169, DenseNet201, \n", 150 | "import models\n", 151 | "\n", 152 | "# Student models:\n", 153 | "# myNet, LeNet, FitNet\n", 154 | "\n", 155 | "start_time = time.time()\n", 156 | "\n", 157 | "# Training settings\n", 158 | "parser = argparse.ArgumentParser(description='PyTorch LR_adaptive_AT')\n", 159 | "\n", 160 | "parser.add_argument('--dataset',\n", 161 | " choices=['CIFAR10',\n", 162 | " 'CIFAR100'\n", 163 | " ],\n", 164 | " default='CIFAR10')\n", 165 | "parser.add_argument('--teacher',\n", 166 | " choices=['ResNet32',\n", 167 | " 'ResNet50',\n", 168 | " 'ResNet56',\n", 169 | " 'ResNet110'\n", 170 | " ],\n", 171 | " default='ResNet110')\n", 172 | "parser.add_argument('--student',\n", 173 | " choices=['ResNet20',\n", 174 | " 'myNet'\n", 175 | " ],\n", 176 | " default='ResNet20')\n", 177 | "parser.add_argument('--dist_ratio', default=1, type=float)\n", 178 | "parser.add_argument('--angle_ratio', default=2, type=float)\n", 179 | "parser.add_argument('--at_ratio', default=1, type=float)\n", 180 | "\n", 181 | "parser.add_argument('--n_class', type=int, default=100, metavar='N', help='num of classes')\n", 182 | "parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 183 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input test batch size for training')\n", 184 | "parser.add_argument('--epochs', type=int, default=20,
metavar='N', help='number of epochs to train (default: 20)')\n", 185 | "parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')\n", 186 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')\n", 187 | "parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')\n", 188 | "parser.add_argument('--print_freq', type=int, default=40, metavar='N', help='how many batches to wait before logging training status')\n", 189 | "\n", 190 | "config = ['--dataset', 'CIFAR100', '--epochs', '200', '--at_ratio', '1', '--device', 'cuda:0']\n", 191 | "args = parser.parse_args(config)\n", 192 | "\n", 193 | "device = args.device if torch.cuda.is_available() else 'cpu'\n", 194 | "load_dir = './checkpoint/' + args.dataset + '/'\n", 195 | "\n", 196 | "# teacher model\n", 197 | "te_model = getattr(models, args.teacher)(num_classes=args.n_class)\n", 198 | "te_model.load_state_dict(torch.load(load_dir + te_model.model_name + '.pth'))\n", 199 | "te_model.to(device)\n", 200 | "te_model.eval() # eval mode\n", 201 | "\n", 202 | "st_model = getattr(models, args.student)(num_classes=args.n_class) # args.student()\n", 203 | "st_model.to(device)\n", 204 | "\n", 205 | "# logging\n", 206 | "logfile = load_dir + 'RKD_' + st_model.model_name + '.log'\n", 207 | "if os.path.exists(logfile):\n", 208 | " os.remove(logfile)\n", 209 | "def log_out(info):\n", 210 | " f = open(logfile, mode='a')\n", 211 | " f.write(info)\n", 212 | " f.write('\\n')\n", 213 | " f.close()\n", 214 | " print(info)\n", 215 | " \n", 216 | "# visualizer\n", 217 | "vis = Visdom(env='distill')\n", 218 | "loss_win = vis.line(\n", 219 | " X=np.array([0]),\n", 220 | " Y=np.array([0]),\n", 221 | " opts=dict(\n", 222 | " title='RKD Loss',\n", 223 | " xlabel='epoch',\n", 224 | " xtickmin=0,\n", 225 | " ylabel='loss',\n", 226 | " ytickmin=0,\n", 227 | " ytickstep=0.5,\n", 228 | " ),\n", 229 | " name=\"loss\"\n", 230 | ")\n", 231 | "\n", 232 | "acc_win = vis.line(\n", 233 | " X=np.column_stack((0, 0)),\n", 234 | " Y=np.column_stack((0, 0)),\n", 235 | " opts=dict(\n", 236 | " title='RKD Acc',\n", 237 | " xlabel='epoch',\n", 238 | " xtickmin=0,\n", 239 | " ylabel='accuracy',\n", 240 | " ytickmin=0,\n", 241 | " ytickmax=100,\n", 242 | " legend=['train_acc', 'test_acc']\n", 243 | " ),\n", 244 | " name=\"acc\"\n", 245 | ")\n", 246 | "\n", 247 | "\n", 248 | "# data\n", 249 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 250 | "train_transform = transforms.Compose([\n", 251 | " transforms.RandomHorizontalFlip(),\n", 252 | " transforms.RandomCrop(32, 4),\n", 253 | " transforms.ToTensor(),\n", 254 | " normalize,\n", 255 | "])\n", 256 | "test_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", 257 | "train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)\n", 258 | "test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)\n", 259 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 260 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 261 | "# optim\n", 262 | "optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)\n", 263 | "lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, gamma=0.1, milestones=[100, 150])\n", 264 | "\n", 265 | "\n",
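metric/loss.py is not shown in this listing; as a rough sketch of what RkdDistance computes in the RKD paper (the names and details below are assumptions, not the repository's exact code), the pairwise-distance structure of student and teacher outputs is matched with a Huber loss after normalizing by the mean distance:

import torch
import torch.nn.functional as F

def pairwise_dist(e):                        # e: (N, D) batch of embeddings
    prod = e @ e.t()
    sq = prod.diag()
    d = (sq.unsqueeze(0) + sq.unsqueeze(1) - 2 * prod).clamp(min=0).sqrt()
    return d / (d[d > 0].mean() + 1e-12)     # scale-invariant: divide by mean distance

def rkd_distance_sketch(student_out, teacher_out):
    return F.smooth_l1_loss(pairwise_dist(student_out), pairwise_dist(teacher_out))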
266 | "# attention transfer loss, distance loss, angular loss\n", 267 | "at_criterion = AttentionTransfer().to(device)\n", 268 | "dist_criterion = RkdDistance().to(device)\n", 269 | "angle_criterion = RKdAngle().to(device)\n", 270 | "\n", 271 | "\n", 272 | "# train with teacher\n", 273 | "def train(epoch, model):\n", 274 | " print('Training:')\n", 275 | " # switch to train mode\n", 276 | " model.train()\n", 277 | " te_model.eval()\n", 278 | " batch_time = AverageMeter()\n", 279 | " data_time = AverageMeter()\n", 280 | " losses = AverageMeter()\n", 281 | " top1 = AverageMeter()\n", 282 | " \n", 283 | " end = time.time()\n", 284 | " for i, (input, target) in enumerate(train_loader):\n", 285 | "\n", 286 | " # measure data loading time\n", 287 | " data_time.update(time.time() - end)\n", 288 | "\n", 289 | " input, target = input.to(device), target.to(device)\n", 290 | " \n", 291 | " # compute outputs\n", 292 | " b1, b2, b3, pool, output = model(input)\n", 293 | " with torch.no_grad():\n", 294 | " t_b1, t_b2, t_b3, t_pool, t_output = te_model(input)\n", 295 | " \n", 296 | " optimizer_sgd.zero_grad()\n", 297 | " \n", 298 | " angle_loss = args.angle_ratio * angle_criterion(output, t_output)\n", 299 | " dist_loss = args.dist_ratio * dist_criterion(output, t_output)\n", 300 | " # attention loss\n", 301 | " at_loss = args.at_ratio * (at_criterion(b1, t_b1) + at_criterion(b2, t_b2) + at_criterion(b3, t_b3))\n", 302 | " entropy_loss = F.cross_entropy(output, target)\n", 303 | " loss = at_loss + angle_loss + dist_loss + entropy_loss\n", 304 | "\n", 305 | " loss.backward() # retain_graph is unnecessary: a fresh graph is built each iteration\n", 306 | " optimizer_sgd.step()\n", 307 | "\n", 308 | " output = output.float()\n", 309 | " loss = loss.float()\n", 310 | " # measure accuracy and record loss\n", 311 | " train_acc = accuracy(output.data, target.data)[0]\n", 312 | " losses.update(loss.item(), input.size(0))\n", 313 | " top1.update(train_acc, input.size(0))\n", 314 | "\n", 315 | " # measure elapsed time\n", 316 | " batch_time.update(time.time() - end)\n", 317 | " end = time.time()\n", 318 | "\n", 319 | " if i % args.print_freq == 0:\n", 320 | " log_out('[{0}/{1}]\\t'\n", 321 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 322 | " 'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n", 323 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 324 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 325 | " i, len(train_loader), batch_time=batch_time,\n", 326 | " data_time=data_time, loss=losses, top1=top1))\n", 327 | " return losses.avg, train_acc.cpu().numpy()\n", 328 | "\n", 329 | "\n", 330 | "def test(model):\n", 331 | " print('Testing:')\n", 332 | " # switch to evaluate mode\n", 333 | " model.eval()\n", 334 | " batch_time = AverageMeter()\n", 335 | " losses = AverageMeter()\n", 336 | " top1 = AverageMeter()\n", 337 | "\n", 338 | " end = time.time()\n", 339 | " with torch.no_grad():\n", 340 | " for i, (input, target) in enumerate(test_loader):\n", 341 | " input, target = input.to(device), target.to(device)\n", 342 | "\n", 343 | " # compute output\n", 344 | " _,_,_,_,output = model(input)\n", 345 | " loss = F.cross_entropy(output, target)\n", 346 | "\n", 347 | " output = output.float()\n", 348 | " loss = loss.float()\n", 349 | "\n", 350 | " # measure accuracy and record loss\n", 351 | " test_acc = accuracy(output.data, target.data)[0]\n", 352 | " losses.update(loss.item(), input.size(0))\n", 353 | " top1.update(test_acc, input.size(0))\n", 354 | "\n", 355 | " # measure elapsed time\n", 356 | " batch_time.update(time.time() - end)\n", 357 | " end =
time.time()\n", 358 | "\n", 359 | " if i % args.print_freq == 0:\n", 360 | " log_out('Test: [{0}/{1}]\\t'\n", 361 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 362 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 363 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 364 | " i, len(test_loader), batch_time=batch_time, loss=losses,\n", 365 | " top1=top1))\n", 366 | "\n", 367 | " log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))\n", 368 | "\n", 369 | " return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()\n", 370 | "\n", 371 | "\n", 372 | "print('StudentNet:\\n')\n", 373 | "print(st_model)\n", 374 | "st_model.apply(weights_init_normal)\n", 375 | "best_acc = 0\n", 376 | "for epoch in range(1, args.epochs + 1):\n", 377 | " log_out(\"\\n===> epoch: {}/{}\".format(epoch, args.epochs))\n", 378 | " log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))\n", 379 | " lr_scheduler.step(epoch)\n", 380 | " train_loss, train_acc = train(epoch, st_model)\n", 381 | " # visualize loss\n", 382 | " vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update=\"append\")\n", 383 | " _, test_acc, top1 = test(st_model)\n", 384 | " vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update=\"append\")\n", 385 | " if top1 > best_acc:\n", 386 | " best_acc = top1\n", 387 | " \n", 388 | "# release GPU memory\n", 389 | "torch.cuda.empty_cache()\n", 390 | "log_out(\"BEST ACC: {:.3f}\".format(best_acc))\n", 391 | "log_out(\"--- {:.3f} mins ---\".format((time.time() - start_time)/60))\n" 392 | ] 393 | } 394 | ], 395 | "metadata": { 396 | "kernelspec": { 397 | "display_name": "Python [conda env:py37] *", 398 | "language": "python", 399 | "name": "conda-env-py37-py" 400 | }, 401 | "language_info": { 402 | "codemirror_mode": { 403 | "name": "ipython", 404 | "version": 3 405 | }, 406 | "file_extension": ".py", 407 | "mimetype": "text/x-python", 408 | "name": "python", 409 | "nbconvert_exporter": "python", 410 | "pygments_lexer": "ipython3", 411 | "version": "3.7.1" 412 | } 413 | }, 414 | "nbformat": 4, 415 | "nbformat_minor": 2 416 | } 417 | -------------------------------------------------------------------------------- /group_Hint_only.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "start_time": "2019-08-27T02:56:00.190Z" 9 | } 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stderr", 14 | "output_type": "stream", 15 | "text": [ 16 | "/home/data/yaliu/jupyterbooks/multi-KD/models/teacher/resnet20.py:36: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n", 17 | " init.kaiming_normal(m.weight)\n", 18 | "WARNING:root:Setting up a new session...\n" 19 | ] 20 | }, 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Files already downloaded and verified\n", 26 | "StudentNet:\n", 27 | "\n", 28 | "ResNet(\n", 29 | " (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 30 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 31 | " (layer1): Sequential(\n", 32 | " (0): BasicBlock(\n", 33 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 34 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 35 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3),
stride=(1, 1), padding=(1, 1), bias=False)\n", 36 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 37 | " (shortcut): Sequential()\n", 38 | " )\n", 39 | " (1): BasicBlock(\n", 40 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 41 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 42 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 43 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 44 | " (shortcut): Sequential()\n", 45 | " )\n", 46 | " (2): BasicBlock(\n", 47 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 48 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 49 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 50 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 51 | " (shortcut): Sequential()\n", 52 | " )\n", 53 | " )\n", 54 | " (layer2): Sequential(\n", 55 | " (0): BasicBlock(\n", 56 | " (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 57 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 58 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 59 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 60 | " (shortcut): LambdaLayer()\n", 61 | " )\n", 62 | " (1): BasicBlock(\n", 63 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 64 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 65 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 66 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 67 | " (shortcut): Sequential()\n", 68 | " )\n", 69 | " (2): BasicBlock(\n", 70 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 71 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 72 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " (shortcut): Sequential()\n", 75 | " )\n", 76 | " )\n", 77 | " (layer3): Sequential(\n", 78 | " (0): BasicBlock(\n", 79 | " (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 82 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (shortcut): LambdaLayer()\n", 84 | " )\n", 85 | " (1): BasicBlock(\n", 86 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 87 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 88 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 89 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (shortcut): 
Sequential()\n", 91 | " )\n", 92 | " (2): BasicBlock(\n", 93 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 94 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 96 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (shortcut): Sequential()\n", 98 | " )\n", 99 | " )\n", 100 | " (linear): Linear(in_features=64, out_features=10, bias=True)\n", 101 | ")\n", 102 | "\n", 103 | "===> epoch: 1/200\n", 104 | "current lr 1.00000e-01\n", 105 | "Training:\n", 106 | "[0/391]\tTime 0.121 (0.121)\tData 0.027 (0.027)\tLoss 3.7064 (3.7064)\tPrec@1 8.594 (8.594)\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:192: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", 114 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:154: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", 115 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py:1992: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n", 116 | " warnings.warn(\"reduction: 'mean' divides the total loss by both the batch size and the support size.\"\n" 117 | ] 118 | }, 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "[10/391]\tTime 0.119 (0.132)\tData 0.022 (0.024)\tLoss 2.5631 (3.0669)\tPrec@1 17.969 (13.281)\n", 124 | "[20/391]\tTime 0.121 (0.132)\tData 0.021 (0.024)\tLoss 2.4680 (2.8043)\tPrec@1 17.188 (16.257)\n", 125 | "[30/391]\tTime 0.115 (0.132)\tData 0.021 (0.024)\tLoss 2.4514 (2.6880)\tPrec@1 19.531 (18.170)\n" 126 | ] 127 | } 128 | ], 129 | "source": [ 130 | "from __future__ import print_function\n", 131 | "import os\n", 132 | "import time\n", 133 | "import logging\n", 134 | "import argparse\n", 135 | "import numpy as np\n", 136 | "from visdom import Visdom\n", 137 | "from PIL import Image\n", 138 | "import torch\n", 139 | "import torch.nn as nn\n", 140 | "import torch.nn.functional as F\n", 141 | "import torch.optim as optim\n", 142 | "from torch.autograd import Variable\n", 143 | "from torch.utils.data import DataLoader\n", 144 | "from torchvision import datasets, transforms\n", 145 | "from utils import *\n", 146 | "from metric.loss import FitNet\n", 147 | "\n", 148 | "# Teacher models:\n", 149 | "# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34, \n", 150 | "# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, \n", 151 | "# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, \n", 152 | "# PreActResNet50, PreActResNet101, PreActResNet152, \n", 153 | "# DenseNet121, DenseNet161, DenseNet169, DenseNet201, \n", 154 | "import models\n", 155 | "\n", 156 | "# Student models:\n", 157 | "# myNet, LeNet, FitNet\n", 158 | "\n", 159 | "start_time = time.time()\n", 160 | "# os.makedirs('./checkpoint', exist_ok=True)\n", 161 | "\n",
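As an illustration of the shape of the objective this notebook assembles (this is not the notebook's exact code; the regressors mapping student features into each teacher's feature space are assumptions), a group hint loss can simply average one FitNet-style regression term per teacher:

import torch.nn.functional as F

def group_hint_loss_sketch(student_feat, teacher_feats, regressors):
    terms = [F.mse_loss(reg(student_feat), t.detach())
             for reg, t in zip(regressors, teacher_feats)]
    return sum(terms) / len(terms)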
162 | "# Training settings\n", 163 | "parser = argparse.ArgumentParser(description='PyTorch ada. FitNet')\n", 164 | "\n", 165 | "parser.add_argument('--dataset',\n", 166 | "                    choices=['CIFAR10',\n", 167 | "                             'CIFAR100'\n", 168 | "                            ],\n", 169 | "                    default='CIFAR10')\n", 170 | "parser.add_argument('--teachers',\n", 171 | "                    choices=['ResNet32',\n", 172 | "                             'ResNet50',\n", 173 | "                             'ResNet56',\n", 174 | "                             'ResNet110',\n", 175 | "                             'DenseNet121'\n", 176 | "                            ],\n", 177 | "                    default=['ResNet32', 'ResNet56', 'ResNet110'],\n", 178 | "                    nargs='+')\n", 179 | "parser.add_argument('--student',\n", 180 | "                    choices=['ResNet8',\n", 181 | "                             'ResNet15',\n", 182 | "                             'ResNet16',\n", 183 | "                             'ResNet20',\n", 184 | "                             'myNet'\n", 185 | "                            ],\n", 186 | "                    default='ResNet20')\n", 187 | "\n", 188 | "parser.add_argument('--kd_ratio', default=0.7, type=float)\n", 189 | "parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')\n", 190 | "parser.add_argument('--T', type=float, default=20.0, metavar='Temperature', help='Temperature for distillation')\n", 191 | "parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 192 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')\n", 193 | "parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')\n", 194 | "parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')\n", 195 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')\n", 196 | "parser.add_argument('--device', default='cuda:0', type=str, help='device: cuda or cpu')\n", 197 | "parser.add_argument('--print_freq', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')\n", 198 | "\n", 199 | "config = ['--epochs', '200', '--T', '5.0', '--device', 'cuda:1']\n", 200 | "args = parser.parse_args(config)\n", 201 | "\n", 202 | "device = args.device if torch.cuda.is_available() else 'cpu'\n", 203 | "load_dir = './checkpoint/' + args.dataset + '/'\n", 204 | "\n", 205 | "# teachers model\n", 206 | "teacher_models = []\n", 207 | "for te in args.teachers:\n", 208 | "    te_model = getattr(models, te)(num_classes=args.n_class)\n", 209 | "    te_model.load_state_dict(torch.load(load_dir + te_model.model_name + '.pth'))\n", 210 | "    te_model.to(device)\n", 211 | "    teacher_models.append(te_model)\n", 212 | "\n", 213 | "st_model = getattr(models, args.student)(num_classes=args.n_class) # args.student()\n", 214 | "st_model.to(device)\n", 215 | "\n", 216 | "# logging\n", 217 | "logfile = load_dir + 'groupHint_' + st_model.model_name + '.log'\n", 218 | "if os.path.exists(logfile):\n", 219 | "    os.remove(logfile)\n", 220 | "def log_out(info):\n", 221 | "    f = open(logfile, mode='a')\n", 222 | "    f.write(info)\n", 223 | "    f.write('\\n')\n", 224 | "    f.close()\n", 225 | "    print(info)\n", 226 | "    \n", 227 | "# visualizer\n", 228 | "vis = Visdom(env='distill')\n", 229 | "loss_win = vis.line(\n", 230 | "    X=np.array([0]),\n", 231 | "    Y=np.array([0]),\n", 232 | "    opts=dict(\n", 233 | "        title='group Hint loss',\n", 234 | "        xtickmin=0,\n", 235 | "#         xtickmax=1,\n", 236 | "#         xtickstep=5,\n", 237 | "        ytickmin=0,\n", 238 | "#         ytickmax=1,\n", 239 | "        ytickstep=0.5,\n", 240 | "#         markers=True,\n", 241 | "#         markersymbol='dot',\n", 242 | "#         markersize=5,\n", 243 | "    ),\n", 244 | "    name=\"loss\"\n", 245 | ")\n", 246 | "\n", 
| "acc_win = vis.line(\n", 248 | " X=np.column_stack((0, 0)),\n", 249 | " Y=np.column_stack((0, 0)),\n", 250 | " opts=dict(\n", 251 | " title='group Hint Acc',\n", 252 | " xtickmin=0,\n", 253 | "# xtickstep=5,\n", 254 | " ytickmin=0,\n", 255 | " ytickmax=100,\n", 256 | "# markers=True,\n", 257 | "# markersymbol='dot',\n", 258 | "# markersize=5,\n", 259 | " legend=['train_acc', 'test_acc']\n", 260 | " ),\n", 261 | " name=\"acc\"\n", 262 | ")\n", 263 | "\n", 264 | "# data\n", 265 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 266 | "train_transform = transforms.Compose([\n", 267 | " transforms.RandomHorizontalFlip(),\n", 268 | " transforms.RandomCrop(32, 4),\n", 269 | " transforms.ToTensor(),\n", 270 | " normalize,\n", 271 | "])\n", 272 | "test_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", 273 | "train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)\n", 274 | "test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)\n", 275 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 276 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 277 | "# optim\n", 278 | "optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)\n", 279 | "lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, gamma=0.1, milestones=[100, 150])\n", 280 | "\n", 281 | "# loss\n", 282 | "def kd_loss(y, labels, weighted_logits, T=10.0, alpha=0.7):\n", 283 | " ls = nn.KLDivLoss()(F.log_softmax(y/T), weighted_logits) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)\n", 284 | " return ls\n", 285 | "\n", 286 | "fitnet_criterion = [FitNet(32, 64), FitNet(64, 64),FitNet(64, 64)]\n", 287 | "[f.to(device) for f in fitnet_criterion]\n", 288 | "\n", 289 | "\n", 290 | "# train with multi-teacher\n", 291 | "def train(epoch, model):\n", 292 | " print('Training:')\n", 293 | " # switch to train mode\n", 294 | " model.train()\n", 295 | " batch_time = AverageMeter()\n", 296 | " data_time = AverageMeter()\n", 297 | " losses = AverageMeter()\n", 298 | " top1 = AverageMeter()\n", 299 | " \n", 300 | " end = time.time()\n", 301 | " for i, (input, target) in enumerate(train_loader):\n", 302 | "\n", 303 | " # measure data loading time\n", 304 | " data_time.update(time.time() - end)\n", 305 | "\n", 306 | " input, target = input.to(device), target.to(device)\n", 307 | " \n", 308 | " # compute outputs\n", 309 | " b1, b2, b3, pool, output = model(input)\n", 310 | " st_maps = [b1, b2, b3, pool]\n", 311 | " \n", 312 | " te_scores_list = []\n", 313 | " hint_maps = []\n", 314 | " fit_loss = 0\n", 315 | " for j,te in enumerate(teacher_models):\n", 316 | " te.eval()\n", 317 | " with torch.no_grad():\n", 318 | " t_b1, t_b2, t_b3, t_pool, t_output = te(input)\n", 319 | " \n", 320 | " hint_maps.append(t_pool)\n", 321 | " t_output = F.softmax(t_output/args.T)\n", 322 | " te_scores_list.append(t_output)\n", 323 | " te_scores_Tensor = torch.stack(te_scores_list, dim=1) # size: [128, 3, 10]\n", 324 | " \n", 325 | " optimizer_sgd.zero_grad()\n", 326 | " \n", 327 | " # compute gradient and do SGD step\n", 328 | " KD_loss = kd_loss(output, target, t_output, T=args.T, alpha=args.kd_ratio)\n", 329 | " for j in range(len(teacher_models)-1):\n", 330 | " fit_loss += fitnet_criterion[j](st_maps[j+1], hint_maps[j])\n", 331 | " \n", 332 | " loss = KD_loss + 
333 | "\n", 334 | "        loss.backward()\n", 335 | "        optimizer_sgd.step()\n", 336 | "\n", 337 | "        output = output.float()\n", 338 | "        loss = loss.float()\n", 339 | "        # measure accuracy and record loss\n", 340 | "        train_acc = accuracy(output.data, target.data)[0]\n", 341 | "        losses.update(loss.item(), input.size(0))\n", 342 | "        top1.update(train_acc, input.size(0))\n", 343 | "\n", 344 | "        # measure elapsed time\n", 345 | "        batch_time.update(time.time() - end)\n", 346 | "        end = time.time()\n", 347 | "\n", 348 | "        if i % args.print_freq == 0:\n", 349 | "            log_out('[{0}/{1}]\\t'\n", 350 | "                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 351 | "                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n", 352 | "                    'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 353 | "                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 354 | "                        i, len(train_loader), batch_time=batch_time,\n", 355 | "                        data_time=data_time, loss=losses, top1=top1))\n", 356 | "    return losses.avg, train_acc.cpu().numpy()\n", 357 | "\n", 358 | "\n", 359 | "def test(model):\n", 360 | "    print('Testing:')\n", 361 | "    # switch to evaluate mode\n", 362 | "    model.eval()\n", 363 | "    batch_time = AverageMeter()\n", 364 | "    losses = AverageMeter()\n", 365 | "    top1 = AverageMeter()\n", 366 | "\n", 367 | "    end = time.time()\n", 368 | "    with torch.no_grad():\n", 369 | "        for i, (input, target) in enumerate(test_loader):\n", 370 | "            input, target = input.to(device), target.to(device)\n", 371 | "\n", 372 | "            # compute output\n", 373 | "            _,_,_,_,output = model(input)\n", 374 | "            loss = F.cross_entropy(output, target)\n", 375 | "\n", 376 | "            output = output.float()\n", 377 | "            loss = loss.float()\n", 378 | "\n", 379 | "            # measure accuracy and record loss\n", 380 | "            test_acc = accuracy(output.data, target.data)[0]\n", 381 | "            losses.update(loss.item(), input.size(0))\n", 382 | "            top1.update(test_acc, input.size(0))\n", 383 | "\n", 384 | "            # measure elapsed time\n", 385 | "            batch_time.update(time.time() - end)\n", 386 | "            end = time.time()\n", 387 | "\n", 388 | "            if i % args.print_freq == 0:\n", 389 | "                log_out('Test: [{0}/{1}]\\t'\n", 390 | "                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 391 | "                        'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 392 | "                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 393 | "                            i, len(test_loader), batch_time=batch_time, loss=losses,\n", 394 | "                            top1=top1))\n", 395 | "\n", 396 | "    log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))\n", 397 | "\n", 398 | "    return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()\n", 399 | "\n", 400 | "# \"\"\"\n", 401 | "print('StudentNet:\\n')\n", 402 | "print(st_model)\n", 403 | "st_model.apply(weights_init_normal)\n", 404 | "best_acc = 0\n", 405 | "for epoch in range(1, args.epochs + 1):\n", 406 | "    log_out(\"\\n===> epoch: {}/{}\".format(epoch, args.epochs))\n", 407 | "    log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))\n", 408 | "    lr_scheduler.step(epoch)\n", 409 | "    train_loss, train_acc = train(epoch, st_model)\n", 410 | "    # visualize loss\n", 411 | "    vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update=\"append\")\n", 412 | "    _, test_acc, top1 = test(st_model)\n", 413 | "    vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update=\"append\")\n", 414 | "    if top1 > best_acc:\n", 415 | "        best_acc = top1\n", 416 | "    \n", 417 | "# release GPU memory\n", 418 | "torch.cuda.empty_cache()\n", 419 | "log_out(\"BEST ACC: {:.3f}\".format(best_acc))\n", 420 | 
"log_out(\"--- {:.3f} mins ---\".format((time.time() - start_time)/60))\n", 421 | "# \"\"\"" 422 | ] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Python [conda env:py37]", 428 | "language": "python", 429 | "name": "conda-env-py37-py" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.7.1" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 2 446 | } 447 | -------------------------------------------------------------------------------- /DML_3S.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2019-08-03T07:35:22.228272Z", 9 | "start_time": "2019-08-03T07:32:19.281366Z" 10 | } 11 | }, 12 | "outputs": [ 13 | { 14 | "name": "stderr", 15 | "output_type": "stream", 16 | "text": [ 17 | "WARNING:root:Setting up a new session...\n" 18 | ] 19 | }, 20 | { 21 | "name": "stdout", 22 | "output_type": "stream", 23 | "text": [ 24 | "Files already downloaded and verified\n", 25 | "*-----------------DML----------------*\n", 26 | "\n", 27 | "===> epoch: 1/200\n", 28 | "Training:\n" 29 | ] 30 | }, 31 | { 32 | "name": "stderr", 33 | "output_type": "stream", 34 | "text": [ 35 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:160: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 36 | ] 37 | }, 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "[0/391]\tTime 0.343 (0.343)\tData 0.019 (0.019)\tLoss 6.4467 (4.5754)\tPrec@1 (7.031)\n", 43 | "[0/391]\tTime 0.343 (0.343)\tData 0.019 (0.019)\tLoss 6.4467 (4.5754)\tPrec@1 (7.031)\n", 44 | "[0/391]\tTime 0.343 (0.343)\tData 0.019 (0.019)\tLoss 6.4467 (4.5754)\tPrec@1 (7.031)\n", 45 | "[40/391]\tTime 0.317 (0.319)\tData 0.018 (0.018)\tLoss 2.3183 (2.7125)\tPrec@1 (13.993)\n", 46 | "[40/391]\tTime 0.317 (0.319)\tData 0.018 (0.018)\tLoss 2.3183 (2.7125)\tPrec@1 (13.993)\n", 47 | "[40/391]\tTime 0.317 (0.319)\tData 0.018 (0.018)\tLoss 2.3183 (2.7125)\tPrec@1 (13.993)\n", 48 | "[80/391]\tTime 0.328 (0.319)\tData 0.018 (0.018)\tLoss 1.9819 (2.3065)\tPrec@1 (16.917)\n", 49 | "[80/391]\tTime 0.328 (0.319)\tData 0.018 (0.018)\tLoss 1.9819 (2.3065)\tPrec@1 (16.917)\n", 50 | "[80/391]\tTime 0.328 (0.319)\tData 0.018 (0.018)\tLoss 1.9819 (2.3065)\tPrec@1 (16.917)\n", 51 | "[120/391]\tTime 0.328 (0.320)\tData 0.018 (0.018)\tLoss 1.9261 (2.1333)\tPrec@1 (19.551)\n", 52 | "[120/391]\tTime 0.328 (0.320)\tData 0.018 (0.018)\tLoss 1.9261 (2.1333)\tPrec@1 (19.551)\n", 53 | "[120/391]\tTime 0.328 (0.320)\tData 0.018 (0.018)\tLoss 1.9261 (2.1333)\tPrec@1 (19.551)\n", 54 | "[160/391]\tTime 0.327 (0.323)\tData 0.018 (0.018)\tLoss 1.8759 (2.0296)\tPrec@1 (21.610)\n", 55 | "[160/391]\tTime 0.327 (0.323)\tData 0.018 (0.018)\tLoss 1.8759 (2.0296)\tPrec@1 (21.610)\n", 56 | "[160/391]\tTime 0.327 (0.323)\tData 0.018 (0.018)\tLoss 1.8759 (2.0296)\tPrec@1 (21.610)\n", 57 | "[200/391]\tTime 0.334 (0.325)\tData 0.018 (0.018)\tLoss 1.7957 (1.9522)\tPrec@1 (23.312)\n", 58 | "[200/391]\tTime 0.334 (0.325)\tData 0.018 (0.018)\tLoss 1.7957 (1.9522)\tPrec@1 (23.312)\n", 59 | "[200/391]\tTime 0.334 (0.325)\tData 
0.018 (0.018)\tLoss 1.7957 (1.9522)\tPrec@1 (23.312)\n", 60 | "[240/391]\tTime 0.352 (0.329)\tData 0.019 (0.018)\tLoss 1.7561 (1.8937)\tPrec@1 (24.799)\n", 61 | "[240/391]\tTime 0.352 (0.329)\tData 0.019 (0.018)\tLoss 1.7561 (1.8937)\tPrec@1 (24.799)\n", 62 | "[240/391]\tTime 0.352 (0.329)\tData 0.019 (0.018)\tLoss 1.7561 (1.8937)\tPrec@1 (24.799)\n", 63 | "[280/391]\tTime 0.379 (0.332)\tData 0.018 (0.018)\tLoss 1.6914 (1.8409)\tPrec@1 (26.356)\n", 64 | "[280/391]\tTime 0.379 (0.332)\tData 0.018 (0.018)\tLoss 1.6914 (1.8409)\tPrec@1 (26.356)\n", 65 | "[280/391]\tTime 0.379 (0.332)\tData 0.018 (0.018)\tLoss 1.6914 (1.8409)\tPrec@1 (26.356)\n", 66 | "[320/391]\tTime 0.336 (0.333)\tData 0.018 (0.018)\tLoss 1.6142 (1.7941)\tPrec@1 (27.836)\n", 67 | "[320/391]\tTime 0.336 (0.333)\tData 0.018 (0.018)\tLoss 1.6142 (1.7941)\tPrec@1 (27.836)\n", 68 | "[320/391]\tTime 0.336 (0.333)\tData 0.018 (0.018)\tLoss 1.6142 (1.7941)\tPrec@1 (27.836)\n", 69 | "[360/391]\tTime 0.333 (0.335)\tData 0.018 (0.018)\tLoss 1.6891 (1.7526)\tPrec@1 (29.186)\n", 70 | "[360/391]\tTime 0.333 (0.335)\tData 0.018 (0.018)\tLoss 1.6891 (1.7526)\tPrec@1 (29.186)\n", 71 | "[360/391]\tTime 0.333 (0.335)\tData 0.018 (0.018)\tLoss 1.6891 (1.7526)\tPrec@1 (29.186)\n", 72 | "Testing:\n", 73 | "Test: [0/79]\tTime 0.019 (0.019)\tLoss 1.4136 (1.4136)\tPrec@1 46.094 (46.094)\n", 74 | "Test: [40/79]\tTime 0.019 (0.018)\tLoss 1.2804 (1.3270)\tPrec@1 54.688 (51.086)\n", 75 | " * ResNet20 Prec@1 50.810\n", 76 | "Testing:\n", 77 | "Test: [0/79]\tTime 0.026 (0.026)\tLoss 1.7197 (1.7197)\tPrec@1 35.938 (35.938)\n", 78 | "Test: [40/79]\tTime 0.025 (0.025)\tLoss 1.7212 (1.7040)\tPrec@1 34.375 (37.024)\n", 79 | " * ResNet56 Prec@1 36.260\n", 80 | "Testing:\n", 81 | "Test: [0/79]\tTime 0.035 (0.035)\tLoss 1.7451 (1.7451)\tPrec@1 39.844 (39.844)\n", 82 | "Test: [40/79]\tTime 0.033 (0.034)\tLoss 1.7589 (1.7887)\tPrec@1 31.250 (32.393)\n", 83 | " * ResNet110 Prec@1 32.120\n", 84 | "\n", 85 | "===> epoch: 2/200\n", 86 | "Training:\n", 87 | "[0/391]\tTime 0.354 (0.354)\tData 0.033 (0.033)\tLoss 1.7618 (1.5796)\tPrec@1 (37.760)\n", 88 | "[0/391]\tTime 0.354 (0.354)\tData 0.033 (0.033)\tLoss 1.7618 (1.5796)\tPrec@1 (37.760)\n", 89 | "[0/391]\tTime 0.354 (0.354)\tData 0.033 (0.033)\tLoss 1.7618 (1.5796)\tPrec@1 (37.760)\n", 90 | "[40/391]\tTime 0.334 (0.343)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.539)\n", 91 | "[40/391]\tTime 0.334 (0.343)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.539)\n", 92 | "[40/391]\tTime 0.334 (0.343)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.539)\n", 93 | "[80/391]\tTime 0.357 (0.346)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.115)\n", 94 | "[80/391]\tTime 0.357 (0.346)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.115)\n", 95 | "[80/391]\tTime 0.357 (0.346)\tData 0.018 (0.019)\tLoss nan (nan)\tPrec@1 (7.115)\n", 96 | "[120/391]\tTime 0.322 (0.348)\tData 0.018 (0.018)\tLoss nan (nan)\tPrec@1 (6.891)\n", 97 | "[120/391]\tTime 0.322 (0.348)\tData 0.018 (0.018)\tLoss nan (nan)\tPrec@1 (6.891)\n", 98 | "[120/391]\tTime 0.322 (0.348)\tData 0.018 (0.018)\tLoss nan (nan)\tPrec@1 (6.891)\n" 99 | ] 100 | }, 101 | { 102 | "ename": "KeyboardInterrupt", 103 | "evalue": "", 104 | "output_type": "error", 105 | "traceback": [ 106 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 107 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 108 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 251\u001b[0m 
\u001b[0;32mfor\u001b[0m \u001b[0mj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0mlr_scheduler_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 253\u001b[0;31m \u001b[0mtrain_loss_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtop1_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnets_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 254\u001b[0m \u001b[0;31m# visaulize loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mvis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mline\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loss_list\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mK\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_win\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mupdate\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"append\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 109 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(epoch, nets_list)\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0mloss_j\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 175\u001b[0m \u001b[0mloss_j\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 176\u001b[0;31m \u001b[0moptimizers_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mj\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 177\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;31m# measure accuracy and record loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 110 | "\u001b[0;32m~/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0mbuf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparam_state\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'momentum_buffer'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 100\u001b[0;31m 
\u001b[0mbuf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmomentum\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mdampening\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    101\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0mnesterov\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m                     \u001b[0md_p\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md_p\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 111 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "from __future__ import print_function\n", 117 | "import os\n", 118 | "import time\n", 119 | "import logging\n", 120 | "import argparse\n", 121 | "import numpy as np\n", 122 | "from visdom import Visdom\n", 123 | "import torch\n", 124 | "import torch.nn as nn\n", 125 | "import torch.nn.functional as F\n", 126 | "import torch.optim as optim\n", 127 | "from torch.autograd import Variable\n", 128 | "from torch.utils.data import DataLoader\n", 129 | "from torchvision import datasets, transforms\n", 130 | "from utils import *\n", 131 | "\n", 132 | "# Teacher models:\n", 133 | "# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34, \n", 134 | "# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, \n", 135 | "# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, \n", 136 | "# PreActResNet50, PreActResNet101, PreActResNet152, \n", 137 | "# DenseNet121, DenseNet161, DenseNet169, DenseNet201, \n", 138 | "import models\n", 139 | "\n", 140 | "# Student models:\n", 141 | "# myNet, LeNet, FitNet\n", 142 | "\n", 143 | "start_time = time.time()\n", 144 | "# os.makedirs('./checkpoint', exist_ok=True)\n", 145 | "\n", 146 | "# Training settings\n", 147 | "parser = argparse.ArgumentParser(description='PyTorch DML 3S')\n", 148 | "\n", 149 | "parser.add_argument('--dataset',\n", 150 | "                    choices=['CIFAR10',\n", 151 | "                             'CIFAR100'\n", 152 | "                            ],\n", 153 | "                    default='CIFAR10')\n", 154 | "parser.add_argument('--nets',\n", 155 | "                    choices=['ResNet20', 'ResNet32',\n", 156 | "                             'ResNet50',\n", 157 | "                             'ResNet56',\n", 158 | "                             'ResNet110'\n", 159 | "                            ],\n", 160 | "                    default=['ResNet20', 'ResNet56', 'ResNet110'],\n", 161 | "                    nargs='+')\n", 162 | "\n", 163 | "parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')\n", 164 | "parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 165 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input batch size for testing')\n", 166 | "parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')\n", 167 | "parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')\n", 168 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')\n", 169 | "parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')\n", 170 | "parser.add_argument('--print_freq', type=int, default=40, metavar='N', help='how many batches to wait before logging training status')\n", 
171 | "\n", 172 | "config = ['--epochs', '200', '--device', 'cuda:1']\n", 173 | "args = parser.parse_args(config)\n", 174 | "\n", 175 | "device = args.device if torch.cuda.is_available() else 'cpu'\n", 176 | "save_dir = './checkpoint/' + args.dataset + '/'\n", 177 | "\n", 178 | "# models\n", 179 | "nets_list = []\n", 180 | "for m in args.nets:\n", 181 | "    net = getattr(models, m)()\n", 182 | "    net.to(device)\n", 183 | "    net.train() # train mode\n", 184 | "    nets_list.append(net)\n", 185 | "\n", 186 | "K = len(nets_list)\n", 187 | "    \n", 188 | "# logging\n", 189 | "logfile = save_dir + 'DML_3S.log'\n", 190 | "if os.path.exists(logfile):\n", 191 | "    os.remove(logfile)\n", 192 | "def log_out(info):\n", 193 | "    f = open(logfile, mode='a')\n", 194 | "    f.write(info)\n", 195 | "    f.write('\\n')\n", 196 | "    f.close()\n", 197 | "    print(info)\n", 198 | "    \n", 199 | "# visualizer\n", 200 | "vis = Visdom(env='distill')\n", 201 | "loss_win = vis.line(\n", 202 | "    X=np.array([0]*K),\n", 203 | "    Y=np.array([0]*K),\n", 204 | "    opts=dict(\n", 205 | "        title='DML_3S Loss',\n", 206 | "        xlabel='epoch',\n", 207 | "        xtickmin=0,\n", 208 | "        ylabel='loss',\n", 209 | "        ytickmin=0,\n", 210 | "        ytickstep=0.5\n", 211 | "    ),\n", 212 | "    name=\"loss\"\n", 213 | ")\n", 214 | "\n", 215 | "acc_win = vis.line(\n", 216 | "    X=np.array([0]*K),\n", 217 | "    Y=np.array([0]*K),\n", 218 | "    opts=dict(\n", 219 | "        title='DML_3S Acc',\n", 220 | "        xlabel='epoch',\n", 221 | "        xtickmin=0,\n", 222 | "        ylabel='accuracy',\n", 223 | "        ytickmin=0,\n", 224 | "        ytickmax=100\n", 225 | "    ),\n", 226 | "    name=\"acc\"\n", 227 | ")\n", 228 | "\n", 229 | "\n", 230 | "# data\n", 231 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 232 | "train_transform = transforms.Compose([\n", 233 | "    transforms.RandomHorizontalFlip(),\n", 234 | "    transforms.RandomCrop(32, 4),\n", 235 | "    transforms.ToTensor(),\n", 236 | "    normalize,\n", 237 | "])\n", 238 | "test_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", 239 | "train_set = datasets.CIFAR10(root='../data', train=True, download=True, transform=train_transform)\n", 240 | "test_set = datasets.CIFAR10(root='../data', train=False, download=False, transform=test_transform)\n", 241 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 242 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 243 | "\n", 244 | "# optimizer = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum)\n", 245 | "optimizers_list = []\n", 246 | "lr_scheduler_list = []\n", 247 | "for m in nets_list:\n", 248 | "    optimizer_m = optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)\n", 249 | "    lr_scheduler_m = optim.lr_scheduler.MultiStepLR(optimizer_m, milestones=[100, 150])\n", 250 | "    optimizers_list.append(optimizer_m)\n", 251 | "    lr_scheduler_list.append(lr_scheduler_m)\n", 252 | "\n", 253 | "    \n", 254 | "# train K peer networks with mutual learning (DML)\n", 255 | "def train(epoch, nets_list):\n", 256 | "    print('Training:')\n", 257 | "    K = len(nets_list)\n", 258 | "\n", 259 | "    batch_time = AverageMeter()\n", 260 | "    data_time = AverageMeter()\n", 261 | "    losses_list = [AverageMeter() for _ in range(K)]\n", 262 | "    top1_list = [AverageMeter() for _ in range(K)]\n", 263 | "    \n", 264 | "    end = time.time()\n", 265 | "    for i, (input, target) in enumerate(train_loader):\n", 266 | "\n", 267 | "        # measure data loading time\n", 268 | "        data_time.update(time.time() - end)\n", 269 | "\n", 270 | "        input, target = input.to(device), target.to(device)\n", 
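"        # DML (deep mutual learning): each net j minimizes CE on the labels plus\n", "        # the average KL divergence from every peer's softened prediction; peer\n", "        # distributions are detached below so loss_j only backprops through net j.\n", 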
271 | "        \n", 272 | "        # compute outputs\n", 273 | "        output_list = []\n", 274 | "        logits_list = []\n", 275 | "        for net in nets_list:\n", 276 | "            _,_,_,_, output_m = net(input)\n", 277 | "            logits_m = F.softmax(output_m, dim=1)\n", 278 | "            output_list.append(output_m)\n", 279 | "            logits_list.append(logits_m)\n", 280 | "        \n", 281 | "        for j in range(K):\n", 282 | "            loss_j = 0\n", 283 | "            \n", 284 | "            optimizers_list[j].zero_grad()\n", 285 | "            for h in range(K):\n", 286 | "                if h != j:\n", 287 | "                    loss_j += nn.KLDivLoss()(F.log_softmax(output_list[j], dim=1), logits_list[h].detach())\n", 288 | "            loss_j /= K - 1\n", 289 | "            loss_j += F.cross_entropy(output_list[j], target)\n", 290 | "            loss_j.backward()\n", 291 | "            optimizers_list[j].step()\n", 292 | "            \n", 293 | "            # measure accuracy and record loss\n", 294 | "            netj_acc = accuracy(output_list[j], target)[0]\n", 295 | "            losses_list[j].update(loss_j.item(), input.size(0))\n", 296 | "            top1_list[j].update(netj_acc, input.size(0))\n", 297 | "        \n", 298 | "        # measure elapsed time\n", 299 | "        batch_time.update(time.time() - end)\n", 300 | "        end = time.time()\n", 301 | "\n", 302 | "        if i % args.print_freq == 0:\n", 303 | "            for j in range(K):\n", 304 | "                log_out('[{0}/{1}]\\t'\n", 305 | "                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 306 | "                        'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n", 307 | "                        'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 308 | "                        'Prec@1 ({top1_1.avg:.3f})'.format(\n", 309 | "                            i, len(train_loader), batch_time=batch_time,\n", 310 | "                            data_time=data_time, loss=losses_list[j], top1_1=top1_list[j]))\n", 311 | "    \n", 312 | "    losses_list = [losses_list[j].avg for j in range(K)]\n", 313 | "    top1_list = [top1_list[j].avg.cpu().numpy() for j in range(K)]\n", 314 | "    \n", 315 | "    return losses_list, top1_list\n", 316 | "\n", 317 | "\n", 318 | "def test(model):\n", 319 | "    print('Testing:')\n", 320 | "    # switch to evaluate mode\n", 321 | "    model.eval()\n", 322 | "    batch_time = AverageMeter()\n", 323 | "    losses = AverageMeter()\n", 324 | "    top1 = AverageMeter()\n", 325 | "\n", 326 | "    end = time.time()\n", 327 | "    with torch.no_grad():\n", 328 | "        for i, (input, target) in enumerate(test_loader):\n", 329 | "            input, target = input.to(device), target.to(device)\n", 330 | "\n", 331 | "            # compute output\n", 332 | "            _,_,_,_,output = model(input)\n", 333 | "            loss = F.cross_entropy(output, target)\n", 334 | "\n", 335 | "            output = output.float()\n", 336 | "            loss = loss.float()\n", 337 | "\n", 338 | "            # measure accuracy and record loss\n", 339 | "            test_acc = accuracy(output.data, target.data)[0]\n", 340 | "            losses.update(loss.item(), input.size(0))\n", 341 | "            top1.update(test_acc, input.size(0))\n", 342 | "\n", 343 | "            # measure elapsed time\n", 344 | "            batch_time.update(time.time() - end)\n", 345 | "            end = time.time()\n", 346 | "\n", 347 | "            if i % args.print_freq == 0:\n", 348 | "                log_out('Test: [{0}/{1}]\\t'\n", 349 | "                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 350 | "                        'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 351 | "                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 352 | "                            i, len(test_loader), batch_time=batch_time, loss=losses,\n", 353 | "                            top1=top1))\n", 354 | "\n", 355 | "    log_out(' * {0} Prec@1 {top1.avg:.3f}'.format(model.model_name, top1=top1))\n", 356 | "\n", 357 | "    return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()\n", 358 | "\n", 359 | "\n", 360 | "print('*-----------------DML----------------*')\n", 361 | "best_acc_list = [0] * K\n", 
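"# [editor's note] visdom's line-append expects X and Y of shape (1, K) when\n", "# appending one point for K traces, which is what the np.column_stack calls\n", "# below produce, e.g. np.column_stack([epoch] * K).shape == (1, K) and\n", "# np.column_stack(np.array(train_loss_list)).shape == (1, K).\n", 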
362 | "for epoch in range(1, args.epochs + 1):\n", 363 | "    log_out(\"\\n===> epoch: {}/{}\".format(epoch, args.epochs))\n", 364 | "#     log_out('current lr {:.5e}'.format(optimizer_1.param_groups[0]['lr']))\n", 365 | "    for j in range(K):\n", 366 | "        lr_scheduler_list[j].step()\n", 367 | "    train_loss_list, top1_list = train(epoch, nets_list)\n", 368 | "    # visualize loss\n", 369 | "    vis.line(np.column_stack(np.array(train_loss_list)), np.column_stack([epoch] * K), loss_win, update=\"append\")\n", 370 | "    top1_list = []\n", 371 | "    for j in range(K):\n", 372 | "        _, _, top1 = test(nets_list[j])\n", 373 | "        best_acc_list[j] = max(top1, best_acc_list[j])\n", 374 | "        top1_list.append(top1)\n", 375 | "    \n", 376 | "    vis.line(np.column_stack(np.array(top1_list)), np.column_stack([epoch] * K), acc_win, update=\"append\")\n", 377 | "    \n", 378 | "for j in range(K):\n", 379 | "    log_out(\"@ [{}] BEST Prec: {:.4f}\".format(nets_list[j].model_name, best_acc_list[j]))\n", 380 | "log_out(\"--- {:.3f} mins ---\".format((time.time() - start_time)/60))\n" 381 | ] 382 | } 383 | ], 384 | "metadata": { 385 | "kernelspec": { 386 | "display_name": "Python [conda env:py37] *", 387 | "language": "python", 388 | "name": "conda-env-py37-py" 389 | }, 390 | "language_info": { 391 | "codemirror_mode": { 392 | "name": "ipython", 393 | "version": 3 394 | }, 395 | "file_extension": ".py", 396 | "mimetype": "text/x-python", 397 | "name": "python", 398 | "nbconvert_exporter": "python", 399 | "pygments_lexer": "ipython3", 400 | "version": "3.7.1" 401 | } 402 | }, 403 | "nbformat": 4, 404 | "nbformat_minor": 2 405 | } 406 | -------------------------------------------------------------------------------- /multi_teacher_avg_distill.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "start_time": "2019-08-18T10:59:47.828Z" 9 | } 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stderr", 14 | "output_type": "stream", 15 | "text": [ 16 | "/home/data/yaliu/jupyterbooks/multi-KD/models/teacher/resnet20.py:36: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n", 17 | "  init.kaiming_normal(m.weight)\n", 18 | "WARNING:root:Setting up a new session...\n" 19 | ] 20 | }, 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Files already downloaded and verified\n", 26 | "StudentNet:\n", 27 | "\n", 28 | "ResNet(\n", 29 | "  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 30 | "  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 31 | "  (layer1): Sequential(\n", 32 | "    (0): BasicBlock(\n", 33 | "      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 34 | "      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 35 | "      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 36 | "      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 37 | "      (shortcut): Sequential()\n", 38 | "    )\n", 39 | "    (1): BasicBlock(\n", 40 | "      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 41 | "      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 42 | "      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 43 | "      (bn2): 
BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 44 | " (shortcut): Sequential()\n", 45 | " )\n", 46 | " (2): BasicBlock(\n", 47 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 48 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 49 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 50 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 51 | " (shortcut): Sequential()\n", 52 | " )\n", 53 | " )\n", 54 | " (layer2): Sequential(\n", 55 | " (0): BasicBlock(\n", 56 | " (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 57 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 58 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 59 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 60 | " (shortcut): LambdaLayer()\n", 61 | " )\n", 62 | " (1): BasicBlock(\n", 63 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 64 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 65 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 66 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 67 | " (shortcut): Sequential()\n", 68 | " )\n", 69 | " (2): BasicBlock(\n", 70 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 71 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 72 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " (shortcut): Sequential()\n", 75 | " )\n", 76 | " )\n", 77 | " (layer3): Sequential(\n", 78 | " (0): BasicBlock(\n", 79 | " (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 82 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (shortcut): LambdaLayer()\n", 84 | " )\n", 85 | " (1): BasicBlock(\n", 86 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 87 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 88 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 89 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (shortcut): Sequential()\n", 91 | " )\n", 92 | " (2): BasicBlock(\n", 93 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 94 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 96 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (shortcut): Sequential()\n", 98 | " )\n", 99 | " )\n", 100 | " (linear): 
Linear(in_features=64, out_features=100, bias=True)\n", 101 | ")\n", 102 | "\n", 103 | "===> epoch: 1/200\n", 104 | "current lr 1.00000e-01\n", 105 | "Training:\n" 106 | ] 107 | }, 108 | { 109 | "name": "stderr", 110 | "output_type": "stream", 111 | "text": [ 112 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:260: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", 113 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:146: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", 114 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py:1992: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n", 115 | "  warnings.warn(\"reduction: 'mean' divides the total loss by both the batch size and the support size.\"\n" 116 | ] 117 | }, 118 | { 119 | "name": "stdout", 120 | "output_type": "stream", 121 | "text": [ 122 | "[0/391]\tTime 0.889 (0.889)\tData 0.041 (0.041)\tLoss 2.2553 (2.2553)\tPrec@1 0.781 (0.781)\n" 123 | ] 124 | } 125 | ], 126 | "source": [ 127 | "### from __future__ import print_function\n", 128 | "import os\n", 129 | "import time\n", 130 | "import logging\n", 131 | "import argparse\n", 132 | "import random\n", 133 | "import numpy as np\n", 134 | "from visdom import Visdom\n", 135 | "import torch\n", 136 | "import torch.nn as nn\n", 137 | "import torch.nn.functional as F\n", 138 | "import torch.optim as optim\n", 139 | "from torch.autograd import Variable\n", 140 | "from torch.utils.data import DataLoader\n", 141 | "from torchvision import datasets, transforms\n", 142 | "from utils import *\n", 143 | "\n", 144 | "# Teacher models:\n", 145 | "# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34, \n", 146 | "# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, \n", 147 | "# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, \n", 148 | "# PreActResNet50, PreActResNet101, PreActResNet152, \n", 149 | "# DenseNet121, DenseNet161, DenseNet169, DenseNet201, \n", 150 | "import models\n", 151 | "\n", 152 | "# Student models:\n", 153 | "# myNet, LeNet, FitNet\n", 154 | "\n", 155 | "start_time = time.time()\n", 156 | "# os.makedirs('./checkpoint', exist_ok=True)\n", 157 | "\n", 158 | "# Training settings\n", 159 | "parser = argparse.ArgumentParser(description='PyTorch multi_teacher_avg_distill')\n", 160 | "\n", 161 | "parser.add_argument('--dataset',\n", 162 | "                    choices=['CIFAR10',\n", 163 | "                             'CIFAR100'\n", 164 | "                            ],\n", 165 | "                    default='CIFAR10')\n", 166 | "parser.add_argument('--teachers',\n", 167 | "                    choices=['ResNet32',\n", 168 | "                             'ResNet50',\n", 169 | "                             'ResNet56',\n", 170 | "                             'ResNet110'\n", 171 | "                            ],\n", 172 | "                    default=['ResNet32', 'ResNet56', 'ResNet110'],\n", 173 | "                    nargs='+')\n", 174 | "parser.add_argument('--student',\n", 175 | "                    choices=['ResNet20',\n", 176 | "                             'myNet'\n", 177 | "                            ],\n", 178 | "                    default='ResNet20')\n", 179 | "\n", 180 | "parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')\n", 181 | "parser.add_argument('--T', type=float, default=20.0, metavar='Temperature', help='Temperature for distillation')\n", 182 | 
"parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 183 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input test batch size for training')\n", 184 | "parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')\n", 185 | "parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.01)')\n", 186 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.5)')\n", 187 | "parser.add_argument('--device', default='cuda:1', type=str, help='device: cuda or cpu')\n", 188 | "parser.add_argument('--print_freq', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')\n", 189 | "\n", 190 | "config = ['--dataset', 'CIFAR100', '--n_class', '100', '--epochs', '200', '--T', '5.0', '--device', 'cuda:1']\n", 191 | "args = parser.parse_args(config)\n", 192 | "\n", 193 | "device = args.device if torch.cuda.is_available() else 'cpu'\n", 194 | "load_dir = './checkpoint/' + args.dataset + '/'\n", 195 | "\n", 196 | "# teachers model\n", 197 | "teacher_models = []\n", 198 | "for te in args.teachers:\n", 199 | " te_model = getattr(models, te)(num_classes=args.n_class)\n", 200 | "# print(te_model)\n", 201 | " te_model.load_state_dict(torch.load(load_dir + te_model.model_name + '.pth'))\n", 202 | " te_model.to(device)\n", 203 | " te_model.eval() # eval mode\n", 204 | " teacher_models.append(te_model)\n", 205 | "\n", 206 | "st_model = getattr(models, args.student)(num_classes=args.n_class) # args.student()\n", 207 | "st_model.to(device)\n", 208 | "\n", 209 | "# logging\n", 210 | "logfile = load_dir + 'avg_distill_' + st_model.model_name + '.log'\n", 211 | "if os.path.exists(logfile):\n", 212 | " os.remove(logfile)\n", 213 | "def log_out(info):\n", 214 | " f = open(logfile, mode='a')\n", 215 | " f.write(info)\n", 216 | " f.write('\\n')\n", 217 | " f.close()\n", 218 | " print(info)\n", 219 | " \n", 220 | "# visualizer\n", 221 | "vis = Visdom(env='distill')\n", 222 | "loss_win = vis.line(\n", 223 | " X=np.array([0]),\n", 224 | " Y=np.array([0]),\n", 225 | " opts=dict(\n", 226 | " title='multi avg. Loss',\n", 227 | " xlabel='epoch',\n", 228 | " xtickmin=0,\n", 229 | " ylabel='loss',\n", 230 | " ytickmin=0,\n", 231 | " ),\n", 232 | " name=\"loss\"\n", 233 | ")\n", 234 | "\n", 235 | "acc_win = vis.line(\n", 236 | " X=np.column_stack((0, 0)),\n", 237 | " Y=np.column_stack((0, 0)),\n", 238 | " opts=dict(\n", 239 | " title='multi-KD avg. 
240 | "        xlabel='epoch',\n", 241 | "        xtickmin=0,\n", 242 | "        ylabel='accuracy',\n", 243 | "        ytickmin=0,\n", 244 | "        ytickmax=100,\n", 245 | "        legend=['train_acc', 'test_acc']\n", 246 | "    ),\n", 247 | "    name=\"acc\"\n", 248 | ")\n", 249 | "\n", 250 | "\n", 251 | "# data\n", 252 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 253 | "train_transform = transforms.Compose([\n", 254 | "    transforms.RandomHorizontalFlip(),\n", 255 | "    transforms.RandomCrop(32, 4),\n", 256 | "    transforms.ToTensor(),\n", 257 | "    normalize,\n", 258 | "])\n", 259 | "test_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", 260 | "train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)\n", 261 | "test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)\n", 262 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 263 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 264 | "\n", 265 | "# optimizer = optim.SGD(st_model.parameters(), lr=args.lr, momentum=args.momentum)\n", 266 | "optimizer = optim.Adam(st_model.parameters(), lr=args.lr)\n", 267 | "optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)\n", 268 | "lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, milestones=[100, 150])\n", 269 | "\n", 270 | "# avg distill\n", 271 | "def distillation_loss(y, labels, logits, T, alpha=0.7):\n", 272 | "    return nn.KLDivLoss()(F.log_softmax(y/T, dim=1), logits) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)\n", 273 | "\n", 274 | "# triplet loss\n", 275 | "triplet_loss = nn.TripletMarginLoss(margin=0.2, p=2).to(device)\n", 276 | "\n", 277 | "# get max infoentropy scores\n", 278 | "# input: Tensor[3, 128, 10]\n", 279 | "def maxInfo_logits(te_scores_Tensor):\n", 280 | "    used_score = torch.FloatTensor(te_scores_Tensor.size(1), te_scores_Tensor.size(2)).to(device)\n", 281 | "    ents = torch.FloatTensor(te_scores_Tensor.size(0), te_scores_Tensor.size(1)).to(device)\n", 282 | "    logp = torch.log2(te_scores_Tensor)\n", 283 | "    plogp = -logp.mul(te_scores_Tensor)\n", 284 | "    for i,te in enumerate(plogp):\n", 285 | "        ents[i] = torch.sum(te, dim=1)\n", 286 | "    max_ent_index = torch.max(ents, dim=0).indices # index of the max-entropy teacher for each column (sample)\n", 287 | "#     print(max_ent_index)\n", 288 | "    for i in range(max_ent_index.size(0)):\n", 289 | "        used_score[i] = te_scores_Tensor[max_ent_index[i].item()][i]\n", 290 | "#     print(used_score)\n", 291 | "\n", 292 | "    return used_score\n", 293 | "    \n", 294 | "# avg logits\n", 295 | "# input: Tensor[128, 3, 10] (batch, teacher, class); averages over teachers\n", 296 | "def avg_logits(te_scores_Tensor):\n", 297 | "#     print(te_scores_Tensor.size())\n", 298 | "    mean_Tensor = torch.mean(te_scores_Tensor, dim=1)\n", 299 | "#     print(mean_Tensor)\n", 300 | "    return mean_Tensor\n", 301 | "    \n", 302 | "# random logits\n", 303 | "def random_logits(te_scores_Tensor):\n", 304 | "    return te_scores_Tensor[np.random.randint(0, te_scores_Tensor.size(0))]\n", 305 | "\n", 306 | "# input: t1, t2 - triplet pair\n", 307 | "def triplet_distance(t1, t2):\n", 308 | "    return (t1 - t2).pow(2).sum()\n", 
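"# [editor's sketch] what maxInfo_logits does, on a tiny assumed example with a\n", "# teacher-first [n_teacher, batch, n_class] layout: per sample it keeps the\n", "# teacher distribution with the highest entropy, i.e. the softest prediction.\n", "#   p = torch.tensor([[[0.9, 0.1]], [[0.6, 0.4]]])  # 2 teachers, 1 sample\n", "#   ent = -(p * torch.log2(p)).sum(dim=2)           # per-teacher entropy\n", "#   ent.argmax(dim=0)                               # -> tensor([1]), the softer teacher\n", "\n", 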
309 | "    \n", 310 | "# get triplets\n", 311 | "def random_triplets(st_maps, te_maps):\n", 312 | "    conflict = 0\n", 313 | "    st_triplet_list = []\n", 314 | "    triplet_set_size = st_maps.size(0)\n", 315 | "    batch_list = [x for x in range(triplet_set_size)]\n", 316 | "    for i in range(triplet_set_size):\n", 317 | "        triplet_index = random.sample(batch_list, 3)\n", 318 | "        anchor_index = triplet_index[0] # denote the 1st triplet item as anchor\n", 319 | "        st_triplet = st_maps[triplet_index]\n", 320 | "        te_triplet = te_maps[triplet_index]\n", 321 | "        distance_01 = triplet_distance(te_triplet[0], te_triplet[1])\n", 322 | "        distance_02 = triplet_distance(te_triplet[0], te_triplet[2])\n", 323 | "        if distance_01 > distance_02:\n", 324 | "            conflict += 1\n", 325 | "            # swap positive and negative (fancy indexing copies the RHS first, so the swap is safe)\n", 326 | "            st_triplet[[1, 2]] = st_triplet[[2, 1]]\n", 327 | "        st_triplet_list.append(st_triplet)\n", 328 | "    \n", 329 | "    st_triplet_batch = torch.stack(st_triplet_list, dim=1)\n", 330 | "    return st_triplet_batch\n", 331 | "    \n", 332 | "# get the smallest conflicts index\n", 333 | "def smallest_conflict_teacher(st_maps, te_maps_list):\n", 334 | "    \n", 335 | "    index = 0\n", 336 | "    triplet_set_size = st_maps.size(0)\n", 337 | "    min_conflict = 1\n", 338 | "    batch_list = [x for x in range(triplet_set_size)]\n", 339 | "    # draw a fresh random triplet per comparison; the 1st triplet item is the anchor\n", 340 | "    for idx, te_maps in enumerate(te_maps_list):\n", 341 | "        conflict = 0\n", 342 | "        for i in range(triplet_set_size):\n", 343 | "            triplet_index = random.sample(batch_list, 3)\n", 344 | "            te_triplet = te_maps[triplet_index]\n", 345 | "            distance_01 = triplet_distance(te_triplet[0], te_triplet[1])\n", 346 | "            distance_02 = triplet_distance(te_triplet[0], te_triplet[2])\n", 347 | "            if distance_01 > distance_02:\n", 348 | "                conflict += 1\n", 349 | "        conflict /= triplet_set_size\n", 350 | "        conflict = min(conflict, (1-conflict))\n", 351 | "        if conflict < min_conflict:\n", 352 | "            min_conflict = conflict\n", 353 | "            index = idx\n", 354 | "    return index\n", 355 | "\n", 356 | "# train with multi-teacher\n", 357 | "def train(epoch, st_model):\n", 358 | "    print('Training:')\n", 359 | "    # switch to train mode\n", 360 | "    st_model.train()\n", 361 | "    batch_time = AverageMeter()\n", 362 | "    data_time = AverageMeter()\n", 363 | "    losses = AverageMeter()\n", 364 | "    top1 = AverageMeter()\n", 365 | "    \n", 366 | "    end = time.time()\n", 367 | "    for i, (input, target) in enumerate(train_loader):\n", 368 | "\n", 369 | "        # measure data loading time\n", 370 | "        data_time.update(time.time() - end)\n", 371 | "\n", 372 | "        input, target = input.to(device), target.to(device)\n", 373 | "        \n", 374 | "        # compute outputs\n", 375 | "        b1, b2, b3, pool, output = st_model(input)\n", 376 | "        st_maps = [b1, b2, b3, pool]\n", 377 | "        \n", 378 | "        te_scores_list = []\n", 379 | "        hint_maps = []\n", 380 | "        for j,te in enumerate(teacher_models):\n", 381 | "            te.eval()\n", 382 | "            with torch.no_grad():\n", 383 | "                t_b1, t_b2, t_b3, t_pool, t_output = te(input)\n", 384 | "            \n", 385 | "            hint_maps.append(t_b2)\n", 386 | "            t_output = F.softmax(t_output/args.T, dim=1)\n", 387 | "            te_scores_list.append(t_output)\n", 388 | "        te_scores_Tensor = torch.stack(te_scores_list, dim=1) # size: [128, 3, 10]\n", 389 | "        mean_logits = avg_logits(te_scores_Tensor)\n", 390 | "        \n", 391 | "        \n", 392 | "        te_index = smallest_conflict_teacher(b2, hint_maps)\n", 393 | "        st_triplets = random_triplets(b2, hint_maps[te_index])\n", 394 | "        \n", 395 | "        optimizer_sgd.zero_grad()\n", 396 | "\n", 397 | "        # compute gradient and do SGD step\n", 398 | "        kd_loss = distillation_loss(output, target, mean_logits, T=args.T, alpha=0.7)\n", 399 | "        relation_loss = triplet_loss(st_triplets[0], st_triplets[1], st_triplets[2])\n", 400 | "        \n", 401 | "        loss = kd_loss + relation_loss\n", 
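"        # note: relation_loss is a relation-style (RKD-like) term -- the student's\n", "        # b2 features must respect the triplet ordering induced by the least-\n", "        # conflicting teacher's features (TripletMarginLoss, margin 0.2, defined above).\n", 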
"\n", 403 | " loss.backward(retain_graph=True)\n", 404 | " optimizer_sgd.step()\n", 405 | "\n", 406 | " output = output.float()\n", 407 | " loss = loss.float()\n", 408 | " # measure accuracy and record loss\n", 409 | " train_acc = accuracy(output.data, target.data)[0]\n", 410 | " losses.update(loss.item(), input.size(0))\n", 411 | " top1.update(train_acc, input.size(0))\n", 412 | "\n", 413 | " # measure elapsed time\n", 414 | " batch_time.update(time.time() - end)\n", 415 | " end = time.time()\n", 416 | "\n", 417 | " if i % args.print_freq == 0:\n", 418 | " log_out('[{0}/{1}]\\t'\n", 419 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 420 | " 'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n", 421 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 422 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 423 | " i, len(train_loader), batch_time=batch_time,\n", 424 | " data_time=data_time, loss=losses, top1=top1))\n", 425 | " return losses.avg, train_acc.cpu().numpy()\n", 426 | "\n", 427 | "\n", 428 | "def test(model):\n", 429 | " print('Testing:')\n", 430 | " # switch to evaluate mode\n", 431 | " model.eval()\n", 432 | " batch_time = AverageMeter()\n", 433 | " losses = AverageMeter()\n", 434 | " top1 = AverageMeter()\n", 435 | "\n", 436 | " end = time.time()\n", 437 | " with torch.no_grad():\n", 438 | " for i, (input, target) in enumerate(test_loader):\n", 439 | " input, target = input.to(device), target.to(device)\n", 440 | "\n", 441 | " # compute output\n", 442 | " _,_,_,_,output = model(input)\n", 443 | " loss = F.cross_entropy(output, target)\n", 444 | "\n", 445 | " output = output.float()\n", 446 | " loss = loss.float()\n", 447 | "\n", 448 | " # measure accuracy and record loss\n", 449 | " test_acc = accuracy(output.data, target.data)[0]\n", 450 | " losses.update(loss.item(), input.size(0))\n", 451 | " top1.update(test_acc, input.size(0))\n", 452 | "\n", 453 | " # measure elapsed time\n", 454 | " batch_time.update(time.time() - end)\n", 455 | " end = time.time()\n", 456 | "\n", 457 | " if i % args.print_freq == 0:\n", 458 | " log_out('Test: [{0}/{1}]\\t'\n", 459 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 460 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 461 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 462 | " i, len(test_loader), batch_time=batch_time, loss=losses,\n", 463 | " top1=top1))\n", 464 | "\n", 465 | " log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))\n", 466 | "\n", 467 | " return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()\n", 468 | "\n", 469 | "\n", 470 | "print('StudentNet:\\n')\n", 471 | "print(st_model)\n", 472 | "best_acc = 0\n", 473 | "for epoch in range(1, args.epochs + 1):\n", 474 | " log_out(\"\\n===> epoch: {}/{}\".format(epoch, args.epochs))\n", 475 | " log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))\n", 476 | " lr_scheduler.step()\n", 477 | " train_loss, train_acc = train(epoch, st_model)\n", 478 | " # visaulize loss\n", 479 | " vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update=\"append\")\n", 480 | " _, test_acc, top1 = test(st_model)\n", 481 | " vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update=\"append\")\n", 482 | " if top1 > best_acc:\n", 483 | " best_acc = top1\n", 484 | " if epoch > 150:\n", 485 | " torch.save(st_model.state_dict(), load_dir + st_model.model_name + '_avg.pth')\n", 486 | "\n", 487 | "# release GPU memory\n", 488 | "torch.cuda.empty_cache()\n", 489 | "log_out(\"@ BEST ACC = 
{:.4f}%\".format(best_acc))\n", 490 | "log_out(\"--- {:.3f} mins ---\".format((time.time() - start_time)/60))\n" 491 | ] 492 | } 493 | ], 494 | "metadata": { 495 | "kernelspec": { 496 | "display_name": "Python [conda env:py37] *", 497 | "language": "python", 498 | "name": "conda-env-py37-py" 499 | }, 500 | "language_info": { 501 | "codemirror_mode": { 502 | "name": "ipython", 503 | "version": 3 504 | }, 505 | "file_extension": ".py", 506 | "mimetype": "text/x-python", 507 | "name": "python", 508 | "nbconvert_exporter": "python", 509 | "pygments_lexer": "ipython3", 510 | "version": "3.7.1" 511 | } 512 | }, 513 | "nbformat": 4, 514 | "nbformat_minor": 2 515 | } 516 | -------------------------------------------------------------------------------- /adaptive_FitNet_teacher-level.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "start_time": "2019-08-19T14:16:34.572Z" 9 | } 10 | }, 11 | "outputs": [ 12 | { 13 | "name": "stderr", 14 | "output_type": "stream", 15 | "text": [ 16 | "/home/data/yaliu/jupyterbooks/multi-KD/models/teacher/resnet20.py:36: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_.\n", 17 | " init.kaiming_normal(m.weight)\n", 18 | "WARNING:root:Setting up a new session...\n" 19 | ] 20 | }, 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "Files already downloaded and verified\n", 26 | "StudentNet:\n", 27 | "\n", 28 | "ResNet(\n", 29 | " (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 30 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 31 | " (layer1): Sequential(\n", 32 | " (0): BasicBlock(\n", 33 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 34 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 35 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 36 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 37 | " (shortcut): Sequential()\n", 38 | " )\n", 39 | " (1): BasicBlock(\n", 40 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 41 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 42 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 43 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 44 | " (shortcut): Sequential()\n", 45 | " )\n", 46 | " (2): BasicBlock(\n", 47 | " (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 48 | " (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 49 | " (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 50 | " (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 51 | " (shortcut): Sequential()\n", 52 | " )\n", 53 | " )\n", 54 | " (layer2): Sequential(\n", 55 | " (0): BasicBlock(\n", 56 | " (conv1): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 57 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 58 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1), bias=False)\n", 59 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 60 | " (shortcut): LambdaLayer()\n", 61 | " )\n", 62 | " (1): BasicBlock(\n", 63 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 64 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 65 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 66 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 67 | " (shortcut): Sequential()\n", 68 | " )\n", 69 | " (2): BasicBlock(\n", 70 | " (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 71 | " (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 72 | " (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 73 | " (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 74 | " (shortcut): Sequential()\n", 75 | " )\n", 76 | " )\n", 77 | " (layer3): Sequential(\n", 78 | " (0): BasicBlock(\n", 79 | " (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n", 80 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 81 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 82 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 83 | " (shortcut): LambdaLayer()\n", 84 | " )\n", 85 | " (1): BasicBlock(\n", 86 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 87 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 88 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 89 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 90 | " (shortcut): Sequential()\n", 91 | " )\n", 92 | " (2): BasicBlock(\n", 93 | " (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 94 | " (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 95 | " (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", 96 | " (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", 97 | " (shortcut): Sequential()\n", 98 | " )\n", 99 | " )\n", 100 | " (linear): Linear(in_features=64, out_features=10, bias=True)\n", 101 | ")\n", 102 | "Training adapter:\n" 103 | ] 104 | }, 105 | { 106 | "name": "stderr", 107 | "output_type": "stream", 108 | "text": [ 109 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:246: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", 110 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:209: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", 111 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/_reduction.py:15: UserWarning: reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.\n", 112 | " warnings.warn(\"reduction='elementwise_mean' is deprecated, please use reduction='mean' instead.\")\n", 113 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/ipykernel_launcher.py:180: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n", 114 | "/home/yaliu/Dev/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/functional.py:1992: UserWarning: reduction: 'mean' divides the total loss by both the batch size and the support size.'batchmean' divides only by the batch size, and aligns with the KL div math definition.'mean' will be changed to behave the same as 'batchmean' in the next major release.\n", 115 | " warnings.warn(\"reduction: 'mean' divides the total loss by both the batch size and the support size.\"\n", 116 | ] 117 | }, 118 | ], 119 | "source": [ 120 | "from __future__ import print_function\n", 121 | "import os\n", 122 | "import time\n", 123 | "import logging\n", 124 | "import argparse\n", 125 | "import numpy as np\n", 126 | "from visdom import Visdom\n", 127 | "from PIL import Image\n", 128 | "import torch\n", 129 | "import torch.nn as nn\n", 130 | "import torch.nn.functional as F\n", 131 | "import torch.optim as optim\n", 132 | "from torch.autograd import Variable\n", 133 | "from torch.utils.data import DataLoader\n", 134 | "from torchvision import datasets, transforms\n", 135 | "from utils import *\n", 136 | "from metric.loss import FitNet, AttentionTransfer, RKdAngle, RkdDistance\n", 137 | "\n", 138 | "# Teacher models:\n", 139 | "# VGG11/VGG13/VGG16/VGG19, GoogLeNet, AlexNet, ResNet18, ResNet34, \n", 140 | "# ResNet50, ResNet101, ResNet152, ResNeXt29_2x64d, ResNeXt29_4x64d, \n", 141 | "# ResNeXt29_8x64d, ResNeXt29_32x64d, PreActResNet18, PreActResNet34, \n", 142 | "# PreActResNet50, PreActResNet101, PreActResNet152, \n", 143 | "# DenseNet121, DenseNet161, DenseNet169, DenseNet201, \n", 144 | "import models\n", 145 | "\n", 146 | "# Student models:\n", 147 | "# myNet, LeNet, FitNet\n", 148 | "\n", 149 | "start_time = time.time()\n", 150 | "# os.makedirs('./checkpoint', exist_ok=True)\n", 151 | "\n", 152 | "# Training settings\n", 153 | "parser = argparse.ArgumentParser(description='PyTorch ada. 
FitNet')\n", 154 | "\n", 155 | "parser.add_argument('--dataset',\n", 156 | " choices=['CIFAR10',\n", 157 | " 'CIFAR100'\n", 158 | " ],\n", 159 | " default='CIFAR10')\n", 160 | "parser.add_argument('--teachers',\n", 161 | " choices=['ResNet32',\n", 162 | " 'ResNet50',\n", 163 | " 'ResNet56',\n", 164 | " 'ResNet110',\n", 165 | " 'DenseNet121'\n", 166 | " ],\n", 167 | " default=['ResNet32', 'ResNet56', 'ResNet110'],\n", 168 | " nargs='+')\n", 169 | "parser.add_argument('--student',\n", 170 | " choices=['ResNet20',\n", 171 | " 'myNet'\n", 172 | " ],\n", 173 | " default='ResNet20')\n", 174 | "\n", 175 | "parser.add_argument('--kd_ratio', default=0.7, type=float)\n", 176 | "parser.add_argument('--n_class', type=int, default=10, metavar='N', help='num of classes')\n", 177 | "parser.add_argument('--T', type=float, default=20.0, metavar='Temputure', help='Temputure for distillation')\n", 178 | "parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='input batch size for training')\n", 179 | "parser.add_argument('--test_batch_size', type=int, default=128, metavar='N', help='input test batch size for training')\n", 180 | "parser.add_argument('--epochs', type=int, default=20, metavar='N', help='number of epochs to train (default: 20)')\n", 181 | "parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.01)')\n", 182 | "parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.5)')\n", 183 | "parser.add_argument('--device', default='cuda:0', type=str, help='device: cuda or cpu')\n", 184 | "parser.add_argument('--print_freq', type=int, default=10, metavar='N', help='how many batches to wait before logging training status')\n", 185 | "\n", 186 | "config = ['--epochs', '200', '--T', '5.0', '--device', 'cuda:0']\n", 187 | "args = parser.parse_args(config)\n", 188 | "\n", 189 | "device = args.device if torch.cuda.is_available() else 'cpu'\n", 190 | "load_dir = './checkpoint/' + args.dataset + '/'\n", 191 | "\n", 192 | "# teachers model\n", 193 | "teacher_models = []\n", 194 | "for te in args.teachers:\n", 195 | " te_model = getattr(models, te)(num_classes=args.n_class)\n", 196 | "# print(te_model)\n", 197 | " te_model.load_state_dict(torch.load(load_dir + te_model.model_name + '.pth'))\n", 198 | " te_model.to(device)\n", 199 | " teacher_models.append(te_model)\n", 200 | "\n", 201 | "st_model = getattr(models, args.student)(num_classes=args.n_class) # args.student()\n", 202 | "st_model.to(device)\n", 203 | "\n", 204 | "# logging\n", 205 | "logfile = load_dir + 'ada_te_fitnet_' + st_model.model_name + '.log'\n", 206 | "if os.path.exists(logfile):\n", 207 | " os.remove(logfile)\n", 208 | "def log_out(info):\n", 209 | " f = open(logfile, mode='a')\n", 210 | " f.write(info)\n", 211 | " f.write('\\n')\n", 212 | " f.close()\n", 213 | " print(info)\n", 214 | " \n", 215 | "# visualizer\n", 216 | "vis = Visdom(env='distill')\n", 217 | "loss_win = vis.line(\n", 218 | " X=np.array([0]),\n", 219 | " Y=np.array([0]),\n", 220 | " opts=dict(\n", 221 | " title='FitNet ada. 
loss',\n", 222 | " xtickmin=0,\n", 223 | "# xtickmax=1,\n", 224 | "# xtickstep=5,\n", 225 | " ytickmin=0,\n", 226 | "# ytickmax=1,\n", 227 | " ytickstep=0.5,\n", 228 | "# markers=True,\n", 229 | "# markersymbol='dot',\n", 230 | "# markersize=5,\n", 231 | " ),\n", 232 | " name=\"loss\"\n", 233 | ")\n", 234 | "\n", 235 | "acc_win = vis.line(\n", 236 | " X=np.column_stack((0, 0)),\n", 237 | " Y=np.column_stack((0, 0)),\n", 238 | " opts=dict(\n", 239 | " title='FitNet ada. ACC',\n", 240 | " xtickmin=0,\n", 241 | "# xtickstep=5,\n", 242 | " ytickmin=0,\n", 243 | " ytickmax=100,\n", 244 | "# markers=True,\n", 245 | "# markersymbol='dot',\n", 246 | "# markersize=5,\n", 247 | " legend=['train_acc', 'test_acc']\n", 248 | " ),\n", 249 | " name=\"acc\"\n", 250 | ")\n", 251 | "\n", 252 | "# data\n", 253 | "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 254 | "train_transform = transforms.Compose([\n", 255 | " transforms.RandomHorizontalFlip(),\n", 256 | " transforms.RandomCrop(32, 4),\n", 257 | " transforms.ToTensor(),\n", 258 | " normalize,\n", 259 | "])\n", 260 | "test_transform = transforms.Compose([transforms.ToTensor(), normalize])\n", 261 | "train_set = getattr(datasets, args.dataset)(root='../data', train=True, download=True, transform=train_transform)\n", 262 | "test_set = getattr(datasets, args.dataset)(root='../data', train=False, download=False, transform=test_transform)\n", 263 | "train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)\n", 264 | "test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False)\n", 265 | "# optim\n", 266 | "optimizer_W = optim.SGD([adapter.W], lr=args.lr, momentum=0.9)\n", 267 | "optimizer_theta = optim.SGD([adapter.theta], lr=args.lr, momentum=0.9)\n", 268 | "optimizer_sgd = optim.SGD(st_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)\n", 269 | "lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sgd, gamma=0.1, milestones=[100, 150])\n", 270 | "lr_scheduler2 = optim.lr_scheduler.MultiStepLR(optimizer_W, milestones=[40, 50])\n", 271 | "lr_scheduler3 = optim.lr_scheduler.MultiStepLR(optimizer_theta, milestones=[40, 50])\n", 272 | "\n", 273 | "# loss\n", 274 | "def kd_criterion(y, labels, weighted_logits, T=10.0, alpha=0.7):\n", 275 | " return nn.KLDivLoss()(F.log_softmax(y/T), weighted_logits) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. 
- alpha)\n", 276 | "\n", 277 | "dist_criterion = RkdDistance().to(device)\n", 278 | "angle_criterion = RKdAngle().to(device)\n", 279 | "fitnet_criterion = [FitNet(32, 64), FitNet(64, 64),FitNet(64, 64)]\n", 280 | "[f.to(device) for f in fitnet_criterion]\n", 281 | "\n", 282 | "# adapter model\n", 283 | "class Adapter():\n", 284 | " def __init__(self, in_models, pool_size):\n", 285 | " # representations of teachers\n", 286 | " pool_ch = pool_size[1] # 64\n", 287 | " pool_w = pool_size[2] # 8\n", 288 | " LR_list = []\n", 289 | " torch.manual_seed(1)\n", 290 | " self.theta = torch.randn(len(in_models), pool_ch).to(device) # [3, 64]\n", 291 | " self.theta.requires_grad_(True)\n", 292 | " \n", 293 | " self.max_feat = nn.MaxPool2d(kernel_size=(pool_w, pool_w), stride=pool_w).to(device)\n", 294 | " self.W = torch.randn(pool_ch, 1).to(device)\n", 295 | " self.W.requires_grad_(True)\n", 296 | " self.val = False\n", 297 | "\n", 298 | " def loss(self, y, labels, weighted_logits, T=10.0, alpha=0.7):\n", 299 | " ls = nn.KLDivLoss()(F.log_softmax(y/T), weighted_logits) * (T*T * 2.0 * alpha) + F.cross_entropy(y, labels) * (1. - alpha)\n", 300 | " if not self.val:\n", 301 | " ls += 0.1 * (torch.sum(self.W * self.W) + torch.sum(torch.sum(self.theta * self.theta, dim=1), dim=0))\n", 302 | " return ls\n", 303 | " \n", 304 | " def gradient(self, lr=0.01):\n", 305 | " self.W.data = self.W.data - lr * self.W.grad.data\n", 306 | " # Manually zero the gradients after updating weights\n", 307 | " self.W.grad.data.zero_()\n", 308 | " \n", 309 | " def eval(self):\n", 310 | " self.val = True\n", 311 | " self.theta.detach()\n", 312 | " self.W.detach()\n", 313 | " \n", 314 | " # input size: [64, 8, 8], [128, 3, 10]\n", 315 | " def forward(self, conv_map, te_logits_list):\n", 316 | " beta = self.max_feat(conv_map)\n", 317 | " beta = torch.squeeze(beta) # [128, 64]\n", 318 | " \n", 319 | " latent_factor = []\n", 320 | " for t in self.theta:\n", 321 | " latent_factor.append(beta * t)\n", 322 | "# latent_factor = torch.stack(latent_factor, dim=0) # [3, 128, 64]\n", 323 | " alpha = []\n", 324 | " for lf in latent_factor: # lf.size:[128, 64]\n", 325 | " alpha.append(lf.mm(self.W))\n", 326 | " alpha = torch.stack(alpha, dim=0) # [3, 128, 1]\n", 327 | " alpha = torch.squeeze(alpha).transpose(0, 1) # [128, 3]\n", 328 | " weight = F.softmax(alpha) # [128, 3]\n", 329 | "\n", 330 | " return weight\n", 331 | "\n", 332 | "# adapter instance\n", 333 | "_,_,_,pool_m,_ = st_model(torch.randn(1,3, 128, 128).to(device)) # get pool_size of student\n", 334 | "# reate adapter instance\n", 335 | "adapter = Adapter(teacher_models, pool_m.size())\n", 336 | "\n", 337 | "\n", 338 | "def teacher_weights(n_epochs=50, model=st_model):\n", 339 | " print('Training adapter:')\n", 340 | " start_time = time.time()\n", 341 | " model.train()\n", 342 | " adapter.eval()\n", 343 | " for ep in range(n_epochs):\n", 344 | " lr_scheduler2.step()\n", 345 | " lr_scheduler3.step()\n", 346 | " for i, (input, target) in enumerate(train_loader):\n", 347 | "\n", 348 | " input, target = input.to(device), target.to(device)\n", 349 | " # compute outputs\n", 350 | " b1, b2, b3, pool, output = model(input) # out_feat: 16, 32, 64, 64, - \n", 351 | " st_maps = [b1, b2, b3, pool]\n", 352 | "# print('b1:{}, b2:{}, b3{}, pool:{}'.format(b1.size(), b2.size(), b3.size(), pool.size()))\n", 353 | "# b1:torch.Size([128, 16, 32, 32]), b2:torch.Size([128, 32, 16, 16]), b3torch.Size([128, 64, 8, 8]), pool:torch.Size([128, 64, 1, 1])\n", 354 | "\n", 355 | " te_scores_list = []\n", 356 | 
" hint_maps = []\n", 357 | " fit_loss = 0\n", 358 | " for j,te in enumerate(teacher_models):\n", 359 | " te.eval()\n", 360 | " with torch.no_grad():\n", 361 | " t_b1, t_b2, t_b3, t_pool, t_output = te(input)\n", 362 | "# print('t_b1:{}, t_b2:{}, t_b3:{}, t_pool:{}'.format(t_b1.size(), t_b2.size(), t_b3.size(), t_pool.size()))\n", 363 | "# t_b1:torch.Size([128, 16, 32, 32]), t_b2:torch.Size([128, 32, 16, 16]), t_b3:torch.Size([128, 64, 8, 8]), t_pool:torch.Size([128, 64, 1, 1])\n", 364 | " hint_maps.append(t_pool)\n", 365 | " t_output = F.softmax(t_output/args.T)\n", 366 | " te_scores_list.append(t_output)\n", 367 | " te_scores_Tensor = torch.stack(te_scores_list, dim=1) # size: [128, 3, 10]\n", 368 | " \n", 369 | " weight = adapter.forward(pool, te_scores_Tensor)\n", 370 | " weight_t = torch.unsqueeze(weight, dim=2) # [128, 3, 1]\n", 371 | " weighted_logits = weight_t * te_scores_Tensor # [128, 3, 10]\n", 372 | " weighted_logits = torch.sum(weighted_logits, dim=1)\n", 373 | " \n", 374 | " optimizer_sgd.zero_grad()\n", 375 | " optimizer_W.zero_grad()\n", 376 | " optimizer_theta.zero_grad()\n", 377 | " \n", 378 | " angle_loss = angle_criterion(output, weighted_logits)\n", 379 | " dist_loss = dist_criterion(output, weighted_logits)\n", 380 | " # compute gradient and do SGD step\n", 381 | " ada_loss = adapter.loss(output, target, weighted_logits, T=args.T, alpha=args.kd_ratio)\n", 382 | " \n", 383 | " for j in range(len(teacher_models)):\n", 384 | " fit_loss += fitnet_criterion[j](st_maps[j+1], hint_maps[j])\n", 385 | "# fit_loss = fitnet_criterion[0](b2, hint_maps[0][3]) + fitnet_criterion[1](b3, hint_maps[1][3]) + fitnet_criterion(pool, hint_maps[2][3])\n", 386 | " loss = ada_loss + dist_loss + angle_loss + fit_loss\n", 387 | " \n", 388 | " loss.backward(retain_graph=True)\n", 389 | " optimizer_sgd.step()\n", 390 | " optimizer_W.step()\n", 391 | " optimizer_theta.step()\n", 392 | "# vis.line(np.array([loss.item()]), np.array([ep]), loss_win, update=\"append\")\n", 393 | " log_out('epoch[{}/{}]adapter Loss: {:.4f}'.format(ep, n_epochs, loss.item()))\n", 394 | " end_time = time.time()\n", 395 | " log_out(\"--- adapter training cost {:.3f} mins ---\".format((end_time - start_time)/60))\n", 396 | " return torch.mean(weight_t, dim=0)\n", 397 | "\n", 398 | "# train with multi-teacher\n", 399 | "def train(epoch, model, te_weights):\n", 400 | " print('Training:')\n", 401 | " # switch to train mode\n", 402 | " model.train()\n", 403 | "\n", 404 | " batch_time = AverageMeter()\n", 405 | " data_time = AverageMeter()\n", 406 | " losses = AverageMeter()\n", 407 | " top1 = AverageMeter()\n", 408 | " \n", 409 | " end = time.time()\n", 410 | " for i, (input, target) in enumerate(train_loader):\n", 411 | "\n", 412 | " # measure data loading time\n", 413 | " data_time.update(time.time() - end)\n", 414 | "\n", 415 | " input, target = input.to(device), target.to(device)\n", 416 | " \n", 417 | " # compute outputs\n", 418 | " b1, b2, b3, pool, output = model(input)\n", 419 | " st_maps = [b1, b2, b3, pool]\n", 420 | " \n", 421 | " te_scores_list = []\n", 422 | " hint_maps = []\n", 423 | " fit_loss = 0\n", 424 | " for j,te in enumerate(teacher_models):\n", 425 | " te.eval()\n", 426 | " with torch.no_grad():\n", 427 | " t_b1, t_b2, t_b3, t_pool, t_output = te(input)\n", 428 | " \n", 429 | " hint_maps.append(t_pool)\n", 430 | " t_output = F.softmax(t_output/args.T)\n", 431 | " \n", 432 | " te_scores_list.append(t_output)\n", 433 | " te_scores_Tensor = torch.stack(te_scores_list, dim=1) # size: [128, 3, 10]\n", 434 | " 
398 | "# train with multi-teacher\n", 399 | "def train(epoch, model, te_weights):\n", 400 | " print('Training:')\n", 401 | " # switch to train mode\n", 402 | " model.train()\n", 403 | "\n", 404 | " batch_time = AverageMeter()\n", 405 | " data_time = AverageMeter()\n", 406 | " losses = AverageMeter()\n", 407 | " top1 = AverageMeter()\n", 408 | " \n", 409 | " end = time.time()\n", 410 | " for i, (input, target) in enumerate(train_loader):\n", 411 | "\n", 412 | " # measure data loading time\n", 413 | " data_time.update(time.time() - end)\n", 414 | "\n", 415 | " input, target = input.to(device), target.to(device)\n", 416 | " \n", 417 | " # compute outputs\n", 418 | " b1, b2, b3, pool, output = model(input)\n", 419 | " st_maps = [b1, b2, b3, pool]\n", 420 | " \n", 421 | " te_scores_list = []\n", 422 | " hint_maps = []\n", 423 | " fit_loss = 0\n", 424 | " for j,te in enumerate(teacher_models):\n", 425 | " te.eval()\n", 426 | " with torch.no_grad():\n", 427 | " t_b1, t_b2, t_b3, t_pool, t_output = te(input)\n", 428 | " \n", 429 | " hint_maps.append(t_pool)\n", 430 | " t_output = F.softmax(t_output/args.T, dim=1)\n", 431 | " \n", 432 | " te_scores_list.append(t_output)\n", 433 | " te_scores_Tensor = torch.stack(te_scores_list, dim=1) # size: [128, 3, 10]\n", 434 | " weighted_logits = te_scores_Tensor * te_weights\n", 435 | " weighted_logits = torch.sum(weighted_logits, dim=1)\n", 436 | " \n", 437 | " optimizer_sgd.zero_grad()\n", 438 | " \n", 439 | " angle_loss = angle_criterion(output, weighted_logits)\n", 440 | " dist_loss = dist_criterion(output, weighted_logits)\n", 441 | " \n", 442 | " # compute gradient and do SGD step\n", 443 | " kd_loss = kd_criterion(output, target, weighted_logits, T=args.T, alpha=args.kd_ratio)\n", 444 | " for j in range(len(teacher_models)):\n", 445 | " fit_loss += fitnet_criterion[j](st_maps[j+1], hint_maps[j])\n", 446 | " \n", 447 | " loss = kd_loss # + dist_loss + angle_loss + fit_loss\n", 448 | "\n", 449 | " loss.backward(retain_graph=True)\n", 450 | " optimizer_sgd.step()\n", 451 | "\n", 452 | " output = output.float()\n", 453 | " loss = loss.float()\n", 454 | " # measure accuracy and record loss\n", 455 | " train_acc = accuracy(output.data, target.data)[0]\n", 456 | " losses.update(loss.item(), input.size(0))\n", 457 | " top1.update(train_acc, input.size(0))\n", 458 | "\n", 459 | " # measure elapsed time\n", 460 | " batch_time.update(time.time() - end)\n", 461 | " end = time.time()\n", 462 | "\n", 463 | " if i % args.print_freq == 0:\n", 464 | " log_out('[{0}/{1}]\\t'\n", 465 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 466 | " 'Data {data_time.val:.3f} ({data_time.avg:.3f})\\t'\n", 467 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 468 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 469 | " i, len(train_loader), batch_time=batch_time,\n", 470 | " data_time=data_time, loss=losses, top1=top1))\n", 471 | " return losses.avg, train_acc.cpu().numpy()\n", 472 | "\n", 473 | "\n", 474 | "def test(model):\n", 475 | " print('Testing:')\n", 476 | " # switch to evaluate mode\n", 477 | " model.eval()\n", 478 | " batch_time = AverageMeter()\n", 479 | " losses = AverageMeter()\n", 480 | " top1 = AverageMeter()\n", 481 | "\n", 482 | " end = time.time()\n", 483 | " with torch.no_grad():\n", 484 | " for i, (input, target) in enumerate(test_loader):\n", 485 | " input, target = input.to(device), target.to(device)\n", 486 | "\n", 487 | " # compute output\n", 488 | " _,_,_,_,output = model(input)\n", 489 | " loss = F.cross_entropy(output, target)\n", 490 | "\n", 491 | " output = output.float()\n", 492 | " loss = loss.float()\n", 493 | "\n", 494 | " # measure accuracy and record loss\n", 495 | " test_acc = accuracy(output.data, target.data)[0]\n", 496 | " losses.update(loss.item(), input.size(0))\n", 497 | " top1.update(test_acc, input.size(0))\n", 498 | "\n", 499 | " # measure elapsed time\n", 500 | " batch_time.update(time.time() - end)\n", 501 | " end = time.time()\n", 502 | "\n", 503 | " if i % args.print_freq == 0:\n", 504 | " log_out('Test: [{0}/{1}]\\t'\n", 505 | " 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\\t'\n", 506 | " 'Loss {loss.val:.4f} ({loss.avg:.4f})\\t'\n", 507 | " 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(\n", 508 | " i, len(test_loader), batch_time=batch_time, loss=losses,\n", 509 | " top1=top1))\n", 510 | "\n", 511 | " log_out(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))\n", 512 | "\n", 513 | " return losses.avg, test_acc.cpu().numpy(), top1.avg.cpu().numpy()\n", 514 | "\n", 515 | "# \"\"\"\n", 516 | "print('StudentNet:\\n')\n", 517 | "print(st_model)\n", 518 | "st_model.apply(weights_init_normal)\n", 519 | "weights = teacher_weights(n_epochs=50)\n", 520 | "weights = weights.detach()\n", 521 | "log_out('------------weight:{}'.format(weights))\n", 522 | "# 
st_model.apply(weights_init_normal)\n", 523 | "best_acc = 0\n", 524 | "for epoch in range(1, args.epochs + 1):\n", 525 | " log_out(\"\\n===> epoch: {}/{}\".format(epoch, args.epochs))\n", 526 | " log_out('current lr {:.5e}'.format(optimizer_sgd.param_groups[0]['lr']))\n", 527 | " lr_scheduler.step(epoch)\n", 528 | " train_loss, train_acc = train(epoch, st_model, weights)\n", 529 | " # visualize loss\n", 530 | " vis.line(np.array([train_loss]), np.array([epoch]), loss_win, update=\"append\")\n", 531 | " _, test_acc, top1 = test(st_model)\n", 532 | " vis.line(np.column_stack((train_acc, top1)), np.column_stack((epoch, epoch)), acc_win, update=\"append\")\n", 533 | " if top1 > best_acc:\n", 534 | " best_acc = top1\n", 535 | " \n", 536 | "# release GPU memory\n", 537 | "torch.cuda.empty_cache()\n", 538 | "log_out(\"BEST ACC: {:.3f}\".format(best_acc))\n", 539 | "log_out(\"--- {:.3f} mins ---\".format((time.time() - start_time)/60))\n", 540 | "# \"\"\"" 541 | ] 542 | } 543 | ], 544 | "metadata": { 545 | "kernelspec": { 546 | "display_name": "Python [conda env:py37] *", 547 | "language": "python", 548 | "name": "conda-env-py37-py" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 2 553 | } 554 | --------------------------------------------------------------------------------