├── .gitignore ├── CIFAR-100 ├── README.md ├── distiller.py ├── load_settings.py ├── models │ ├── PyramidNet.py │ ├── ResNet.py │ └── WideResNet.py └── train_with_distillation.py ├── ImageNet ├── README.md ├── distiller.py ├── models │ ├── MobileNet.py │ └── ResNet.py ├── train_with_distillation.py └── utils.py ├── LICENSE ├── NOTICE ├── README.md ├── Segmentation ├── README.md ├── dataloaders │ ├── __init__.py │ ├── custom_transforms.py │ ├── datasets │ │ ├── __init__.py │ │ ├── cityscapes.py │ │ ├── coco.py │ │ ├── combine_dbs.py │ │ ├── pascal.py │ │ └── sbd.py │ └── utils.py ├── distiller.py ├── modeling │ ├── __init__.py │ ├── aspp.py │ ├── backbone │ │ ├── __init__.py │ │ ├── drn.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ └── xception.py │ ├── decoder.py │ ├── deeplab.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py ├── mypath.py ├── pretrained │ └── .gitignore ├── train.py ├── train_voc.sh ├── train_with_distillation.py └── utils │ ├── __init__.py │ ├── calculate_weights.py │ ├── loss.py │ ├── lr_scheduler.py │ ├── metrics.py │ ├── saver.py │ └── summaries.py └── data └── .gitignore /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /CIFAR-100/README.md: -------------------------------------------------------------------------------- 1 | ## CIFAR-100 2 | 3 | ### Settings 4 | We provide code for the experimental settings specified in the paper. 5 | 6 | | Setup | Compression type | Teacher | Student | Teacher size | Student size | Size ratio | 7 | |:-----:|:----------------:|:-----------:|:-----------:|:------------:|:------------:|:----------:| 8 | | (a) | Depth | WRN 28-4 | WRN 16-4 | 5.87M | 2.77M | 47.2% | 9 | | (b) | Channel | WRN 28-4 | WRN 28-2 | 5.87M | 1.47M | 25.0% | 10 | | (c) | Depth & channel | WRN 28-4 | WRN 16-2 | 5.87M | 0.70M | 11.9% | 11 | | (d) | Architecture | WRN 28-4 | ResNet 56 | 5.87M | 0.86M | 14.7% | 12 | | (e) | Architecture | Pyramid-200 | WRN 28-4 | 26.84M | 5.87M | 21.9% | 13 | | (f) | Architecture | Pyramid-200 | Pyramid-110 | 26.84M | 3.91M | 14.6% | 14 | 15 | ### Teacher models 16 | Download the following pre-trained teacher networks and put them into the ```./data``` directory: 17 | - [Wide Residual Network 28-4](https://drive.google.com/open?id=1Quxgs5teXVXwD3jBdkk-WeNLNpxbiZXN) 18 | - [PyramidNet-200(240)](https://drive.google.com/open?id=1_QgG81fNM3OvVIbMAxDPykKWuSIyKnmz) 19 | 20 | ### Training 21 | Run ```CIFAR-100/train_with_distillation.py``` with one of the setting letters (a-f): 22 | ``` 23 | python train_with_distillation.py \ 24 | --setting a \ 25 | --epochs 200 \ 26 | --batch_size 128 \ 27 | --lr 0.1 \ 28 | --momentum 0.9 \ 29 | --weight_decay 5e-4 30 | ``` 31 | 32 | For the PyramidNet teacher settings (e, f), we used batch size 64 to save GPU memory: 33 | ``` 34 | python train_with_distillation.py \ 35 | --setting e \ 36 | --epochs 200 \ 37 | --batch_size 64 \ 38 | --lr 0.1 \ 39 | --momentum 0.9 \ 40 | --weight_decay 5e-4 41 | ``` 42 | 
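Each setting letter selects a (teacher, student) pair through ```load_settings.load_paper_settings```. The snippet below is a minimal, illustrative way to inspect a setting without training (assumes it is run from the ```CIFAR-100``` directory with the teacher checkpoints above already in ```../data```):
```
import argparse
import load_settings

# check setting (c): WRN 28-4 teacher, WRN 16-2 student
args = argparse.Namespace(data_path='../data', paper_setting='c')
teacher, student, args = load_settings.load_paper_settings(args)
print(sum(p.numel() for p in teacher.parameters()))  # ~5.87M parameters (WRN 28-4)
print(sum(p.numel() for p in student.parameters()))  # ~0.70M parameters (WRN 16-2)
```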
43 | ### Experimental results 44 | 45 | | Setup | Teacher | Student | Original | Proposed | Improvement | 46 | |:-----:|:-----------:|:-----------:|:--------:|:--------:|:-----------:| 47 | | (a) | WRN 28-4 | WRN 16-4 | 22.72% | 20.89% | 1.83% | 48 | | (b) | WRN 28-4 | WRN 28-2 | 24.88% | 21.98% | 2.90% | 49 | | (c) | WRN 28-4 | WRN 16-2 | 27.32% | 24.08% | 3.24% | 50 | | (d) | WRN 28-4 | ResNet 56 | 27.68% | 24.44% | 3.24% | 51 | | (e) | Pyramid-200 | WRN 28-4 | 21.09% | 17.80% | 3.29% | 52 | | (f) | Pyramid-200 | Pyramid-110 | 22.58% | 18.89% | 3.69% | 53 | -------------------------------------------------------------------------------- /CIFAR-100/distiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.stats import norm 5 | import scipy 6 | 7 | import math 8 | 9 | def distillation_loss(source, target, margin): 10 | target = torch.max(target, margin)  # clamp teacher features from below by the channel-wise margin 11 | loss = torch.nn.functional.mse_loss(source, target, reduction="none") 12 | loss = loss * ((source > target) | (target > 0)).float()  # penalize only where the teacher is active or the student exceeds the clamped target 13 | return loss.sum() 14 | 15 | def build_feature_connector(t_channel, s_channel): 16 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 17 | nn.BatchNorm2d(t_channel)] 18 | 19 | for m in C: 20 | if isinstance(m, nn.Conv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. / n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | return nn.Sequential(*C) 28 | 29 | def get_margin_from_BN(bn): 30 | margin = [] 31 | std = bn.weight.data 32 | mean = bn.bias.data 33 | for (s, m) in zip(std, mean): 34 | s = abs(s.item()) 35 | m = m.item() 36 | if norm.cdf(-m / s) > 0.001: 37 | margin.append(- s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m)  # closed-form E[x | x < 0] for x ~ N(m, s^2) 38 | else: 39 | margin.append(-3 * s) 40 | 41 | return torch.FloatTensor(margin).to(std.device) 42 | 43 | class Distiller(nn.Module): 44 | def __init__(self, t_net, s_net): 45 | super(Distiller, self).__init__() 46 | 47 | t_channels = t_net.get_channel_num() 48 | s_channels = s_net.get_channel_num() 49 | 50 | self.Connectors = nn.ModuleList([build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]) 51 | 52 | teacher_bns = t_net.get_bn_before_relu() 53 | margins = [get_margin_from_BN(bn) for bn in teacher_bns] 54 | for i, margin in enumerate(margins): 55 | self.register_buffer('margin%d' % (i+1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach()) 56 | 57 | self.t_net = t_net 58 | self.s_net = s_net 59 | 60 | def forward(self, x): 61 | 62 | t_feats, t_out = self.t_net.extract_feature(x, preReLU=True) 63 | s_feats, s_out = self.s_net.extract_feature(x, preReLU=True) 64 | feat_num = len(t_feats) 65 | 66 | loss_distill = 0 67 | for i in range(feat_num): 68 | s_feats[i] = self.Connectors[i](s_feats[i]) 69 | loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i+1))) \ 70 | / 2 ** (feat_num - i - 1) 71 | 72 | return s_out, loss_distill 73 | -------------------------------------------------------------------------------- /CIFAR-100/load_settings.py: -------------------------------------------------------------------------------- 1 | 2 
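The per-channel margin computed by ```get_margin_from_BN``` in ```distiller.py``` above is the closed form of E[x | x < 0] for x ~ N(m, s^2), i.e. the expected value of the teacher's pre-ReLU response over the region ReLU discards, with (m, s) taken from the BatchNorm (bias, |weight|). A quick Monte-Carlo sanity check of that closed form (the (m, s) values here are illustrative):
```
import math
import torch
from scipy.stats import norm

m, s = -0.3, 1.2  # illustrative BN (bias, |weight|) for one channel
closed_form = - s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m

x = torch.randn(1000000) * s + m      # samples of the modeled pre-ReLU response
print(closed_form, x[x < 0].mean().item())  # the two values should agree closely
```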
| import torch 3 | import os 4 | 5 | import models.WideResNet as WRN 6 | import models.PyramidNet as PYN 7 | import models.ResNet as RN 8 | 9 | def load_paper_settings(args): 10 | 11 | WRN_path = os.path.join(args.data_path, 'WRN28-4_21.09.pt') 12 | Pyramid_path = os.path.join(args.data_path, 'pyramid200_mixup_15.6.tar') 13 | 14 | if args.paper_setting == 'a': 15 | teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100) 16 | state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model'] 17 | teacher.load_state_dict(state) 18 | student = WRN.WideResNet(depth=16, widen_factor=4, num_classes=100) 19 | 20 | elif args.paper_setting == 'b': 21 | teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100) 22 | state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model'] 23 | teacher.load_state_dict(state) 24 | student = WRN.WideResNet(depth=28, widen_factor=2, num_classes=100) 25 | 26 | elif args.paper_setting == 'c': 27 | teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100) 28 | state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model'] 29 | teacher.load_state_dict(state) 30 | student = WRN.WideResNet(depth=16, widen_factor=2, num_classes=100) 31 | 32 | elif args.paper_setting == 'd': 33 | teacher = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100) 34 | state = torch.load(WRN_path, map_location={'cuda:0': 'cpu'})['model'] 35 | teacher.load_state_dict(state) 36 | student = RN.ResNet(depth=56, num_classes=100) 37 | 38 | elif args.paper_setting == 'e': 39 | teacher = PYN.PyramidNet(depth=200, alpha=240, num_classes=100, bottleneck=True) 40 | state = torch.load(Pyramid_path, map_location={'cuda:0': 'cpu'})['state_dict'] 41 | from collections import OrderedDict 42 | new_state = OrderedDict() 43 | for k, v in state.items(): 44 | name = k[7:] # remove 'module.' of dataparallel 45 | new_state[name] = v 46 | teacher.load_state_dict(new_state) 47 | student = WRN.WideResNet(depth=28, widen_factor=4, num_classes=100) 48 | 49 | elif args.paper_setting == 'f': 50 | teacher = PYN.PyramidNet(depth=200, alpha=240, num_classes=100, bottleneck=True) 51 | state = torch.load(Pyramid_path, map_location={'cuda:0': 'cpu'})['state_dict'] 52 | from collections import OrderedDict 53 | new_state = OrderedDict() 54 | for k, v in state.items(): 55 | name = k[7:] # remove 'module.' 
of dataparallel 56 | new_state[name] = v 57 | teacher.load_state_dict(new_state) 58 | student = PYN.PyramidNet(depth=110, alpha=84, num_classes=100, bottleneck=False) 59 | 60 | else: 61 | print('Undefined setting name !!!') 62 | 63 | return teacher, student, args -------------------------------------------------------------------------------- /CIFAR-100/models/PyramidNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | #from math import round 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | outchannel_ratio = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn3 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | 29 | out = self.bn1(x) 30 | out = self.conv1(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | out = self.conv2(out) 34 | out = self.bn3(out) 35 | if self.downsample is not None: 36 | shortcut = self.downsample(x) 37 | featuremap_size = shortcut.size()[2:4] 38 | else: 39 | shortcut = x 40 | featuremap_size = out.size()[2:4] 41 | 42 | batch_size = out.size()[0] 43 | residual_channel = out.size()[1] 44 | shortcut_channel = shortcut.size()[1] 45 | 46 | if residual_channel != shortcut_channel: 47 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 48 | out += torch.cat((shortcut, padding), 1) 49 | else: 50 | out += shortcut 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | outchannel_ratio = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 59 | super(Bottleneck, self).__init__() 60 | self.bn1 = nn.BatchNorm2d(inplanes) 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, (planes), kernel_size=3, stride=stride, padding=1, bias=False, groups=1) 64 | self.bn3 = nn.BatchNorm2d((planes)) 65 | self.conv3 = nn.Conv2d((planes), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 66 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | 74 | out = self.bn1(x) 75 | out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | out = self.bn4(out) 86 | if self.downsample is not None: 87 | shortcut = self.downsample(x) 88 | featuremap_size = shortcut.size()[2:4] 89 | else: 90 | shortcut = x 91 | featuremap_size = out.size()[2:4] 92 | 93 | batch_size = out.size()[0] 94 | residual_channel = out.size()[1] 95 | shortcut_channel = shortcut.size()[1] 96 | 97 | if residual_channel != shortcut_channel: 98 | padding = 
torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 99 | out += torch.cat((shortcut, padding), 1) 100 | else: 101 | out += shortcut 102 | 103 | return out 104 | 105 | 106 | class PyramidNet(nn.Module): 107 | 108 | def __init__(self, depth, alpha, num_classes, bottleneck=False): 109 | super(PyramidNet, self).__init__() 110 | self.inplanes = 16 111 | if bottleneck == True: 112 | n = int((depth - 2) / 9) 113 | block = Bottleneck 114 | else: 115 | n = int((depth - 2) / 6) 116 | block = BasicBlock 117 | 118 | self.addrate = alpha / (3*n*1.0) 119 | 120 | self.input_featuremap_dim = self.inplanes 121 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 122 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 123 | 124 | self.featuremap_dim = self.input_featuremap_dim 125 | self.layer1 = self.pyramidal_make_layer(block, n) 126 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 127 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 128 | 129 | self.final_featuremap_dim = self.input_featuremap_dim 130 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 131 | self.relu_final = nn.ReLU(inplace=True) 132 | self.avgpool = nn.AvgPool2d(8) 133 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 134 | 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() 142 | 143 | def pyramidal_make_layer(self, block, block_depth, stride=1): 144 | downsample = None 145 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 146 | downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True) 147 | 148 | layers = [] 149 | self.featuremap_dim = self.featuremap_dim + self.addrate 150 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample)) 151 | for i in range(1, block_depth): 152 | temp_featuremap_dim = self.featuremap_dim + self.addrate 153 | layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1)) 154 | self.featuremap_dim = temp_featuremap_dim 155 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 156 | 157 | return nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | x = self.conv1(x) 161 | x = self.bn1(x) 162 | 163 | x = self.layer1(x) 164 | x = self.layer2(x) 165 | x = self.layer3(x) 166 | 167 | x = self.bn_final(x) 168 | x = self.relu_final(x) 169 | x = self.avgpool(x) 170 | x = x.view(x.size(0), -1) 171 | x = self.fc(x) 172 | return x 173 | 174 | def get_bn_before_relu(self): 175 | bn1 = self.layer2[0].bn2 176 | bn2 = self.layer3[0].bn2 177 | bn3 = self.bn_final 178 | 179 | return [bn1, bn2, bn3] 180 | 181 | def get_channel_num(self): 182 | 183 | if isinstance(self.layer1[0], Bottleneck): 184 | nChannel1 = self.layer2[0].conv1.out_channels 185 | nChannel2 = self.layer3[0].conv1.out_channels 186 | nChannel3 = self.final_featuremap_dim 187 | elif isinstance(self.layer1[0], BasicBlock): 188 | nChannel1 = self.layer2[0].conv1.in_channels 189 | nChannel2 = self.layer3[0].conv1.in_channels 190 | nChannel3 = self.final_featuremap_dim 191 | else: 192 | print('PyramidNet unknown block error !!!') 193 | 194 | return [nChannel1, 
nChannel2, nChannel3] 195 | 196 | def extract_feature(self, x, preReLU=False): 197 | x = self.conv1(x) 198 | x = self.bn1(x) 199 | 200 | feat1 = self.layer1(x) 201 | feat2 = self.layer2(feat1) 202 | feat3 = self.layer3(feat2) 203 | 204 | x = self.bn_final(feat3) 205 | x = self.relu_final(x) 206 | x = self.avgpool(x) 207 | x = x.view(x.size(0), -1) 208 | out = self.fc(x) 209 | 210 | if preReLU: 211 | if isinstance(self.layer1[0], Bottleneck): 212 | l = self.layer2[0] 213 | feat1 = l.bn2(l.conv1(l.bn1(feat1))) 214 | l = self.layer3[0] 215 | feat2 = l.bn2(l.conv1(l.bn1(feat2))) 216 | feat3 = self.bn_final(feat3) 217 | elif isinstance(self.layer1[0], BasicBlock): 218 | feat1 = self.layer2[0].bn1(feat1) 219 | feat2 = self.layer3[0].bn1(feat2) 220 | feat3 = self.bn_final(feat3) 221 | else: 222 | print('PyramidNet unknown block error !!!') 223 | 224 | return [feat1, feat2, feat3], out 225 | -------------------------------------------------------------------------------- /CIFAR-100/models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.conv2 = conv3x3(planes, planes) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | self.relu = nn.ReLU(inplace=True) 23 | 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | x = F.relu(x) 29 | residual = x 30 | 31 | out = self.conv1(x) 32 | out = self.bn1(out) 33 | out = self.relu(out) 34 | 35 | out = self.conv2(out) 36 | out = self.bn2(out) 37 | 38 | if self.downsample is not None: 39 | residual = self.downsample(x) 40 | 41 | out += residual 42 | # out = self.relu(out) 43 | 44 | return out 45 | 46 | 47 | class Bottleneck(nn.Module): 48 | expansion = 4 49 | 50 | def __init__(self, inplanes, planes, stride=1, downsample=None): 51 | super(Bottleneck, self).__init__() 52 | 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 56 | self.bn2 = nn.BatchNorm2d(planes) 57 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False) 58 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 59 | self.relu = nn.ReLU(inplace=True) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | x = F.relu(x) 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | # out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class ResNet(nn.Module): 88 | def __init__(self, depth, num_classes, bottleneck=False): 89 | super(ResNet, self).__init__() 90 | self.inplanes = 16 91 | print(bottleneck) 92 | if bottleneck 
== True: 93 | n = int((depth - 2) / 9) 94 | block = Bottleneck 95 | else: 96 | n = int((depth - 2) / 6) 97 | block = BasicBlock 98 | 99 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 100 | self.bn1 = nn.BatchNorm2d(self.inplanes) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.layer1 = self._make_layer(block, 16, n) 103 | self.layer2 = self._make_layer(block, 32, n, stride=2) 104 | self.layer3 = self._make_layer(block, 64, n, stride=2) 105 | self.avgpool = nn.AvgPool2d(8) 106 | self.fc = nn.Linear(64 * block.expansion, num_classes) 107 | 108 | for m in self.modules(): 109 | if isinstance(m, nn.Conv2d): 110 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 111 | m.weight.data.normal_(0, math.sqrt(2. / n)) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | m.weight.data.fill_(1) 114 | m.bias.data.zero_() 115 | 116 | def _make_layer(self, block, planes, blocks, stride=1): 117 | downsample = None 118 | if stride != 1 or self.inplanes != planes * block.expansion: 119 | downsample = nn.Sequential( 120 | nn.Conv2d(self.inplanes, planes * block.expansion, 121 | kernel_size=1, stride=stride, bias=False), 122 | nn.BatchNorm2d(planes * block.expansion), 123 | ) 124 | 125 | layers = [] 126 | layers.append(block(self.inplanes, planes, stride, downsample)) 127 | self.inplanes = planes * block.expansion 128 | for i in range(1, blocks): 129 | layers.append(block(self.inplanes, planes)) 130 | 131 | return nn.Sequential(*layers) 132 | 133 | def forward(self, x): 134 | 135 | x = self.conv1(x) 136 | x = self.bn1(x) 137 | 138 | x = self.layer1(x) 139 | x = self.layer2(x) 140 | x = self.layer3(x) 141 | 142 | x = F.relu(x) 143 | x = self.avgpool(x) 144 | x = x.view(x.size(0), -1) 145 | x = self.fc(x) 146 | 147 | return x 148 | 149 | def get_bn_before_relu(self): 150 | if isinstance(self.layer1[0], Bottleneck): 151 | bn1 = self.layer1[-1].bn3 152 | bn2 = self.layer2[-1].bn3 153 | bn3 = self.layer3[-1].bn3 154 | elif isinstance(self.layer1[0], BasicBlock): 155 | bn1 = self.layer1[-1].bn2 156 | bn2 = self.layer2[-1].bn2 157 | bn3 = self.layer3[-1].bn2 158 | else: 159 | print('ResNet unknown block error !!!') 160 | 161 | return [bn1, bn2, bn3] 162 | 163 | def get_channel_num(self): 164 | 165 | return [16, 32, 64] 166 | 167 | def extract_feature(self, x, preReLU=False): 168 | 169 | x = self.conv1(x) 170 | x = self.bn1(x) 171 | 172 | feat1 = self.layer1(x) 173 | feat2 = self.layer2(feat1) 174 | feat3 = self.layer3(feat2) 175 | 176 | x = F.relu(feat3) 177 | x = self.avgpool(x) 178 | x = x.view(x.size(0), -1) 179 | out = self.fc(x) 180 | 181 | if not preReLU: 182 | feat1 = F.relu(feat1) 183 | feat2 = F.relu(feat2) 184 | feat3 = F.relu(feat3) 185 | 186 | return [feat1, feat2, feat3], out 187 | -------------------------------------------------------------------------------- /CIFAR-100/models/WideResNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0): 9 | super(BasicBlock, self).__init__() 10 | self.bn1 = nn.BatchNorm2d(in_planes) 11 | self.relu1 = nn.ReLU(inplace=True) 12 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 13 | padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(out_planes) 15 | self.relu2 = nn.ReLU(inplace=True) 16 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 
17 | padding=1, bias=False) 18 | self.droprate = dropRate 19 | self.equalInOut = (in_planes == out_planes) 20 | self.convShortcut = (not self.equalInOut) and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 21 | padding=0, bias=False) or None 22 | 23 | def forward(self, x): 24 | if not self.equalInOut: 25 | x = self.relu1(self.bn1(x)) 26 | else: 27 | out = self.relu1(self.bn1(x)) 28 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 29 | if self.droprate > 0: 30 | out = F.dropout(out, p=self.droprate, training=self.training) 31 | out = self.conv2(out) 32 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 33 | 34 | class NetworkBlock(nn.Module): 35 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 36 | super(NetworkBlock, self).__init__() 37 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 38 | 39 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 40 | layers = [] 41 | for i in range(int(nb_layers)): 42 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 43 | return nn.Sequential(*layers) 44 | 45 | def forward(self, x): 46 | return self.layer(x) 47 | 48 | class WideResNet(nn.Module): 49 | def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0): 50 | super(WideResNet, self).__init__() 51 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 52 | assert((depth - 4) % 6 == 0) 53 | n = (depth - 4) / 6 54 | block = BasicBlock 55 | # 1st conv before any network block 56 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) 57 | # 1st block 58 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 59 | # 2nd block 60 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 61 | # 3rd block 62 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 63 | # global average pooling and classifier 64 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.fc = nn.Linear(nChannels[3], num_classes) 67 | self.nChannels = nChannels 68 | 69 | for m in self.modules(): 70 | if isinstance(m, nn.Conv2d): 71 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 72 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 73 | elif isinstance(m, nn.BatchNorm2d): 74 | m.weight.data.fill_(1) 75 | m.bias.data.zero_() 76 | elif isinstance(m, nn.Linear): 77 | m.bias.data.zero_() 78 | 79 | def forward(self, x): 80 | out = self.conv1(x) 81 | out = self.block1(out) 82 | out = self.block2(out) 83 | out = self.block3(out) 84 | out = self.relu(self.bn1(out)) 85 | out = F.avg_pool2d(out, 8) 86 | out = out.view(-1, self.nChannels[3]) 87 | return self.fc(out) 88 | 89 | def get_bn_before_relu(self): 90 | bn1 = self.block2.layer[0].bn1 91 | bn2 = self.block3.layer[0].bn1 92 | bn3 = self.bn1 93 | 94 | return [bn1, bn2, bn3] 95 | 96 | def get_channel_num(self): 97 | 98 | return self.nChannels[1:] 99 | 100 | def extract_feature(self, x, preReLU=False): 101 | out = self.conv1(x) 102 | feat1 = self.block1(out) 103 | feat2 = self.block2(feat1) 104 | feat3 = self.block3(feat2) 105 | out = self.relu(self.bn1(feat3)) 106 | out = F.avg_pool2d(out, 8) 107 | out = out.view(-1, self.nChannels[3]) 108 | out = self.fc(out) 109 | 110 | if preReLU: 111 | feat1 = self.block2.layer[0].bn1(feat1) 112 | feat2 = self.block3.layer[0].bn1(feat2) 113 | feat3 = self.bn1(feat3) 114 | 115 | return [feat1, feat2, feat3], out 116 | 117 | -------------------------------------------------------------------------------- /CIFAR-100/train_with_distillation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | import torch.nn.functional as F 13 | import argparse 14 | 15 | import distiller 16 | import load_settings 17 | 18 | parser = argparse.ArgumentParser(description='CIFAR-100 training') 19 | parser.add_argument('--data_path', type=str, default='../data') 20 | parser.add_argument('--paper_setting', default='a', type=str) 21 | parser.add_argument('--epochs', default=200, type=int, help='number of total epochs to run') 22 | parser.add_argument('--batch_size', default=128, type=int, help='mini-batch size (default: 256)') 23 | parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate') 24 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 25 | parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight decay') 26 | args = parser.parse_args() 27 | 28 | gpu_num = 0 29 | use_cuda = torch.cuda.is_available() 30 | transform_train = transforms.Compose([ 31 | transforms.Pad(4, padding_mode='reflect'), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.RandomCrop(32), 34 | transforms.ToTensor(), 35 | transforms.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0, 36 | np.array([63.0, 62.1, 66.7]) / 255.0) 37 | ]) 38 | 39 | transform_test = transforms.Compose([ 40 | transforms.ToTensor(), 41 | transforms.Normalize(np.array([125.3, 123.0, 113.9]) / 255.0, 42 | np.array([63.0, 62.1, 66.7]) / 255.0), 43 | ]) 44 | 45 | trainset = torchvision.datasets.CIFAR100(root=args.data_path, train=True, download=True, transform=transform_train) 46 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4) 47 | testset = torchvision.datasets.CIFAR100(root=args.data_path, train=False, download=True, transform=transform_test) 48 | testloader = torch.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=4) 49 | 50 
| # Model 51 | t_net, s_net, args = load_settings.load_paper_settings(args) 52 | 53 | # Module for distillation 54 | d_net = distiller.Distiller(t_net, s_net) 55 | 56 | print('the number of teacher model parameters: {}'.format(sum([p.data.nelement() for p in t_net.parameters()]))) 57 | print('the number of student model parameters: {}'.format(sum([p.data.nelement() for p in s_net.parameters()]))) 58 | 59 | if use_cuda: 60 | torch.cuda.set_device(0) 61 | d_net.cuda() 62 | s_net.cuda() 63 | t_net.cuda() 64 | cudnn.benchmark = True 65 | 66 | criterion_CE = nn.CrossEntropyLoss() 67 | 68 | # Training 69 | def train_with_distill(d_net, epoch): 70 | epoch_start_time = time.time() 71 | print('\nDistillation epoch: %d' % epoch) 72 | 73 | d_net.train() 74 | d_net.s_net.train() 75 | d_net.t_net.train() 76 | 77 | train_loss = 0 78 | correct = 0 79 | total = 0 80 | 81 | global optimizer 82 | for batch_idx, (inputs, targets) in enumerate(trainloader): 83 | if use_cuda: 84 | inputs, targets = inputs.cuda(), targets.cuda() 85 | optimizer.zero_grad() 86 | 87 | batch_size = inputs.shape[0] 88 | outputs, loss_distill = d_net(inputs) 89 | loss_CE = criterion_CE(outputs, targets) 90 | 91 | loss = loss_CE + loss_distill.sum() / batch_size / 1000 92 | 93 | loss.backward() 94 | optimizer.step() 95 | 96 | train_loss += loss_CE.item() 97 | 98 | _, predicted = torch.max(outputs.data, 1) 99 | total += targets.size(0) 100 | correct += predicted.eq(targets.data).cpu().sum().float().item() 101 | 102 | b_idx = batch_idx 103 | 104 | print('Train \t Time Taken: %.2f sec' % (time.time() - epoch_start_time)) 105 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (train_loss / (b_idx + 1), 100. * correct / total, correct, total)) 106 | 107 | return train_loss / (b_idx + 1) 108 | 109 | def test(net): 110 | epoch_start_time = time.time() 111 | net.eval() 112 | test_loss = 0 113 | correct = 0 114 | total = 0 115 | for batch_idx, (inputs, targets) in enumerate(testloader): 116 | if use_cuda: 117 | inputs, targets = inputs.cuda(), targets.cuda() 118 | outputs = net(inputs) 119 | loss = criterion_CE(outputs, targets) 120 | 121 | test_loss += loss.item() 122 | _, predicted = torch.max(outputs.data, 1) 123 | total += targets.size(0) 124 | correct += predicted.eq(targets.data).cpu().sum().float().item() 125 | b_idx = batch_idx 126 | 127 | print('Test \t Time Taken: %.2f sec' % (time.time() - epoch_start_time)) 128 | print('Loss: %.3f | Acc: %.3f%% (%d/%d)' % (test_loss / (b_idx + 1), 100. 
* correct / total, correct, total)) 129 | return test_loss / (b_idx + 1), correct / total 130 | 131 | 132 | print('Performance of teacher network') 133 | test(t_net) 134 | 135 | for epoch in range(args.epochs): 136 | if epoch == 0: 137 | optimizer = optim.SGD([{'params': s_net.parameters()}, {'params': d_net.Connectors.parameters()}], 138 | lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 139 | elif epoch == (args.epochs // 2): 140 | optimizer = optim.SGD([{'params': s_net.parameters()}, {'params': d_net.Connectors.parameters()}], 141 | lr=args.lr / 10, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 142 | elif epoch == (args.epochs * 3 // 4): 143 | optimizer = optim.SGD([{'params': s_net.parameters()}, {'params': d_net.Connectors.parameters()}], 144 | lr=args.lr / 100, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 145 | 146 | train_loss = train_with_distill(d_net, epoch) 147 | test_loss, accuracy = test(s_net) 148 | -------------------------------------------------------------------------------- /ImageNet/README.md: -------------------------------------------------------------------------------- 1 | ## ImageNet 2 | 3 | ### Settings 4 | 5 | | Setup | Compression type | Teacher | Student | Teacher size | Student size | Size ratio | 6 | |:-----:|:----------------:|:----------:|:---------:|:------------:|:------------:|:----------:| 7 | | (a) | Depth | ResNet 152 | ResNet 50 | 60.19M | 25.56M | 42.47% | 8 | | (b) | Architecture | ResNet 50 | MobileNet | 25.56M | 4.23M | 16.55% | 9 | 10 | 11 | For ImageNet, the teacher models are downloaded automatically from the PyTorch model zoo. 12 | 13 | ### Training 14 | 15 | - (a) : ResNet152 to ResNet50 16 | ``` 17 | python train_with_distillation.py \ 18 | --data_path your/path/to/ImageNet \ 19 | --net_type resnet \ 20 | --epochs 100 \ 21 | --lr 0.1 \ 22 | --batch_size 256 23 | ``` 24 | 25 | - (b) : ResNet50 to MobileNet 26 | ``` 27 | python train_with_distillation.py \ 28 | --data_path your/path/to/ImageNet \ 29 | --net_type mobilenet \ 30 | --epochs 100 \ 31 | --lr 0.1 \ 32 | --batch_size 256 33 | ``` 34 | 35 | ### Experimental results 36 | 37 | - ResNet 50 38 | 39 | | Network | Method | Top1-error | Top5-error | 40 | |:----------:|:--------:|:----------:|:----------:| 41 | | ResNet 152 | Teacher | 21.69 | 5.95 | 42 | | ResNet 50 | Original | 23.72 | 6.97 | 43 | | | Proposed | __21.65__ | __5.83__ | 44 | 45 | - MobileNet 46 | 47 | | Network | Method | Top1-error | Top5-error | 48 | |:---------:|:--------:|:----------:|:----------:| 49 | | ResNet 50 | Teacher | 23.84 | 7.14 | 50 | | Mobilenet | Original | 31.13 | 11.24 | 51 | | | Proposed | __28.75__ | __9.66__ | 52 | -------------------------------------------------------------------------------- /ImageNet/distiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.stats import norm 5 | import scipy 6 | 7 | import math 8 | 9 | def distillation_loss(source, target, margin): 10 | target = torch.max(target, margin) 11 | loss = torch.nn.functional.mse_loss(source, target, reduction="none") 12 | loss = loss * ((source > target) | (target > 0)).float() 13 | return loss.sum() 14 | 15 | def build_feature_connector(t_channel, s_channel): 16 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 17 | nn.BatchNorm2d(t_channel)] 18 | 19 | for m in C: 20 | if isinstance(m, nn.Conv2d): 21 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. / n)) 23 | elif isinstance(m, nn.BatchNorm2d): 24 | m.weight.data.fill_(1) 25 | m.bias.data.zero_() 26 | 27 | return nn.Sequential(*C) 28 | 29 | def get_margin_from_BN(bn): 30 | margin = [] 31 | std = bn.weight.data 32 | mean = bn.bias.data 33 | for (s, m) in zip(std, mean): 34 | s = abs(s.item()) 35 | m = m.item() 36 | if norm.cdf(-m / s) > 0.001: 37 | margin.append(- s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m) 38 | else: 39 | margin.append(-3 * s) 40 | 41 | return torch.FloatTensor(margin).to(std.device) 42 | 43 | class Distiller(nn.Module): 44 | def __init__(self, t_net, s_net): 45 | super(Distiller, self).__init__() 46 | 47 | t_channels = t_net.get_channel_num() 48 | s_channels = s_net.get_channel_num() 49 | 50 | self.Connectors = nn.ModuleList([build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)]) 51 | 52 | teacher_bns = t_net.get_bn_before_relu() 53 | margins = [get_margin_from_BN(bn) for bn in teacher_bns] 54 | for i, margin in enumerate(margins): 55 | self.register_buffer('margin%d' % (i+1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach()) 56 | 57 | self.t_net = t_net 58 | self.s_net = s_net 59 | 60 | def forward(self, x): 61 | 62 | t_feats, t_out = self.t_net.extract_feature(x, preReLU=True) 63 | s_feats, s_out = self.s_net.extract_feature(x, preReLU=True) 64 | feat_num = len(t_feats) 65 | 66 | loss_distill = 0 67 | for i in range(feat_num): 68 | s_feats[i] = self.Connectors[i](s_feats[i]) 69 | loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i+1))) \ 70 | / 2 ** (feat_num - i - 1) 71 | 72 | return s_out, loss_distill 73 | -------------------------------------------------------------------------------- /ImageNet/models/MobileNet.py: -------------------------------------------------------------------------------- 1 | '''MobileNet in PyTorch. 2 | See the paper "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 3 | for more details. 
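Each conv_dw block below is a depthwise 3x3 convolution followed by a pointwise 1x1 convolution (each with BatchNorm and ReLU); conv_bn is the standard 3x3 stem, and the stack ends in 7x7 average pooling and a 1000-way linear classifier.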
4 | ''' 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | class MobileNet(nn.Module): 10 | def __init__(self): 11 | super(MobileNet, self).__init__() 12 | 13 | def conv_bn(inp, oup, stride): 14 | return nn.Sequential( 15 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 16 | nn.BatchNorm2d(oup), 17 | nn.ReLU(inplace=True) 18 | ) 19 | 20 | def conv_dw(inp, oup, stride): 21 | return nn.Sequential( 22 | nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), 23 | nn.BatchNorm2d(inp), 24 | nn.ReLU(inplace=True), 25 | 26 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 27 | nn.BatchNorm2d(oup), 28 | nn.ReLU(inplace=True), 29 | ) 30 | 31 | self.model = nn.Sequential( 32 | conv_bn(3, 32, 2), 33 | conv_dw(32, 64, 1), 34 | conv_dw(64, 128, 2), 35 | conv_dw(128, 128, 1), 36 | conv_dw(128, 256, 2), 37 | conv_dw(256, 256, 1), 38 | conv_dw(256, 512, 2), 39 | conv_dw(512, 512, 1), 40 | conv_dw(512, 512, 1), 41 | conv_dw(512, 512, 1), 42 | conv_dw(512, 512, 1), 43 | conv_dw(512, 512, 1), 44 | conv_dw(512, 1024, 2), 45 | conv_dw(1024, 1024, 1), 46 | nn.AvgPool2d(7), 47 | ) 48 | self.fc = nn.Linear(1024, 1000) 49 | 50 | def forward(self, x): 51 | x = self.model(x) 52 | x = x.view(-1, 1024) 53 | x = self.fc(x) 54 | return x 55 | 56 | def get_bn_before_relu(self): 57 | bn1 = self.model[3][-2] 58 | bn2 = self.model[5][-2] 59 | bn3 = self.model[11][-2] 60 | bn4 = self.model[13][-2] 61 | 62 | return [bn1, bn2, bn3, bn4] 63 | 64 | def get_channel_num(self): 65 | 66 | return [128, 256, 512, 1024] 67 | 68 | def extract_feature(self, x, preReLU=False): 69 | 70 | feat1 = self.model[3][:-1](self.model[0:3](x)) 71 | feat2 = self.model[5][:-1](self.model[4:5](F.relu(feat1))) 72 | feat3 = self.model[11][:-1](self.model[6:11](F.relu(feat2))) 73 | feat4 = self.model[13][:-1](self.model[12:13](F.relu(feat3))) 74 | 75 | out = self.model[14](F.relu(feat4)) 76 | out = out.view(-1, 1024) 77 | out = self.fc(out) 78 | 79 | if not preReLU: 80 | feat1 = F.relu(feat1) 81 | feat2 = F.relu(feat2) 82 | feat3 = F.relu(feat3) 83 | feat4 = F.relu(feat4) 84 | 85 | return [feat1, feat2, feat3, feat4], out 86 | -------------------------------------------------------------------------------- /ImageNet/models/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | import torch.nn.functional as F 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | 
self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | x = F.relu(x) 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | #out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | x = F.relu(x) 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | #out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000): 103 | self.inplanes = 64 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = nn.BatchNorm2d(64) 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | self.avgpool = nn.AvgPool2d(7, stride=1) 115 | self.fc = nn.Linear(512 * block.expansion, num_classes) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 120 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 121 | elif isinstance(m, nn.BatchNorm2d): 122 | m.weight.data.fill_(1) 123 | m.bias.data.zero_() 124 | 125 | def _make_layer(self, block, planes, blocks, stride=1): 126 | downsample = None 127 | if stride != 1 or self.inplanes != planes * block.expansion: 128 | downsample = nn.Sequential( 129 | nn.Conv2d(self.inplanes, planes * block.expansion, 130 | kernel_size=1, stride=stride, bias=False), 131 | nn.BatchNorm2d(planes * block.expansion), 132 | ) 133 | 134 | layers = [] 135 | layers.append(block(self.inplanes, planes, stride, downsample)) 136 | self.inplanes = planes * block.expansion 137 | for i in range(1, blocks): 138 | layers.append(block(self.inplanes, planes)) 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | x = self.conv1(x) 143 | x = self.bn1(x) 144 | x = self.relu(x) 145 | x = self.maxpool(x) 146 | 147 | x = self.layer1(x) 148 | x = self.layer2(x) 149 | x = self.layer3(x) 150 | x = F.relu(self.layer4(x)) 151 | 152 | x = self.avgpool(x) 153 | x = x.view(x.size(0), -1) 154 | x = self.fc(x) 155 | 156 | return x 157 | 158 | def get_bn_before_relu(self): 159 | if isinstance(self.layer1[0], Bottleneck): 160 | bn1 = self.layer1[-1].bn3 161 | bn2 = self.layer2[-1].bn3 162 | bn3 = self.layer3[-1].bn3 163 | bn4 = self.layer4[-1].bn3 164 | elif isinstance(self.layer1[0], BasicBlock): 165 | bn1 = self.layer1[-1].bn2 166 | bn2 = self.layer2[-1].bn2 167 | bn3 = self.layer3[-1].bn2 168 | bn4 = self.layer4[-1].bn2 169 | else: 170 | print('ResNet unknown block error !!!') 171 | 172 | return [bn1, bn2, bn3, bn4] 173 | 174 | def get_channel_num(self): 175 | 176 | return [256, 512, 1024, 2048] 177 | 178 | def extract_feature(self, x, preReLU=False): 179 | 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | feat1 = self.layer1(x) 186 | feat2 = self.layer2(feat1) 187 | feat3 = self.layer3(feat2) 188 | feat4 = self.layer4(feat3) 189 | 190 | x = self.avgpool(F.relu(feat4)) 191 | x = x.view(x.size(0), -1) 192 | out = self.fc(x) 193 | 194 | if not preReLU: 195 | feat1 = F.relu(feat1) 196 | feat2 = F.relu(feat2) 197 | feat3 = F.relu(feat3) 198 | feat4 = F.relu(feat4) 199 | 200 | return [feat1, feat2, feat3, feat4], out 201 | 202 | 203 | def resnet18(pretrained=False, **kwargs): 204 | """Constructs a ResNet-18 model. 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | """ 208 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 209 | if pretrained: 210 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 211 | return model 212 | 213 | 214 | def resnet34(pretrained=False, **kwargs): 215 | """Constructs a ResNet-34 model. 216 | Args: 217 | pretrained (bool): If True, returns a model pre-trained on ImageNet 218 | """ 219 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 220 | if pretrained: 221 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 222 | return model 223 | 224 | 225 | def resnet50(pretrained=False, **kwargs): 226 | """Constructs a ResNet-50 model. 227 | Args: 228 | pretrained (bool): If True, returns a model pre-trained on ImageNet 229 | """ 230 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 231 | if pretrained: 232 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 233 | return model 234 | 235 | 236 | def resnet101(pretrained=False, **kwargs): 237 | """Constructs a ResNet-101 model. 
238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | """ 241 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 242 | if pretrained: 243 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 244 | return model 245 | 246 | 247 | def resnet152(pretrained=False, **kwargs): 248 | """Constructs a ResNet-152 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 255 | return model -------------------------------------------------------------------------------- /ImageNet/train_with_distillation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | import time 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torchvision.transforms as transforms 10 | import torchvision.datasets as datasets 11 | import gc 12 | 13 | import models.MobileNet as Mov 14 | import models.ResNet as ResNet 15 | import distiller 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | parser = argparse.ArgumentParser(description='PyTorch ImageNet-1k Training') 21 | parser.add_argument('--data_path', type=str, help='path to dataset') 22 | parser.add_argument('--net_type', default='resnet', type=str, help='networktype: resnet, mobilenet') 23 | parser.add_argument('-j', '--workers', default=8, type=int, help='number of data loading workers (default: 4)') 24 | parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run') 25 | parser.add_argument('-b', '--batch_size', default=256, type=int, help='mini-batch size (default: 256)') 26 | parser.add_argument('--lr', '--learning_rate', default=0.1, type=float, help='initial learning rate') 27 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 28 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay (default: 1e-4)') 29 | parser.add_argument('--print_freq', default=500, type=int, help='print frequency (default: 500)') 30 | 31 | best_err1 = 100 32 | best_err5 = 100 33 | 34 | def main(): 35 | global args, best_err1, best_err5 36 | args = parser.parse_args() 37 | 38 | traindir = os.path.join(args.data_path, 'train') 39 | valdir = os.path.join(args.data_path, 'val') 40 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 41 | std=[0.229, 0.224, 0.225]) 42 | 43 | train_dataset = datasets.ImageFolder(traindir, 44 | transforms.Compose([ 45 | transforms.RandomResizedCrop(224), 46 | transforms.RandomHorizontalFlip(), 47 | transforms.ToTensor(), 48 | normalize]) 49 | ) 50 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 51 | num_workers=args.workers, pin_memory=True, sampler=None) 52 | val_dataset = datasets.ImageFolder(valdir, 53 | transforms.Compose([ 54 | transforms.Resize(256), 55 | transforms.CenterCrop(224), 56 | transforms.ToTensor(), 57 | normalize]) 58 | ) 59 | val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, 60 | num_workers=args.workers, pin_memory=True) 61 | 62 | if args.net_type == 'mobilenet': 63 | t_net = ResNet.resnet50(pretrained=True) 64 | s_net = Mov.MobileNet() 65 | elif args.net_type == 'resnet': 66 | t_net = ResNet.resnet152(pretrained=True) 67 | s_net = 
ResNet.resnet50(pretrained=False) 68 | else: 69 | raise ValueError('undefined network type: %s' % args.net_type) 70 | 71 | 72 | d_net = distiller.Distiller(t_net, s_net) 73 | 74 | print ('Teacher Net: ') 75 | print(t_net) 76 | print ('Student Net: ') 77 | print(s_net) 78 | print('the number of teacher model parameters: {}'.format(sum([p.data.nelement() for p in t_net.parameters()]))) 79 | print('the number of student model parameters: {}'.format(sum([p.data.nelement() for p in s_net.parameters()]))) 80 | 81 | t_net = torch.nn.DataParallel(t_net).cuda() 82 | s_net = torch.nn.DataParallel(s_net).cuda() 83 | d_net = torch.nn.DataParallel(d_net).cuda() 84 | 85 | # define loss function (criterion) and optimizer 86 | criterion_CE = nn.CrossEntropyLoss().cuda() 87 | optimizer = torch.optim.SGD(list(s_net.parameters()) + list(d_net.module.Connectors.parameters()), args.lr, 88 | momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) 89 | cudnn.benchmark = True 90 | 91 | print('Teacher network performance') 92 | validate(val_loader, t_net, criterion_CE, 0) 93 | 94 | for epoch in range(1, args.epochs+1): 95 | adjust_learning_rate(optimizer, epoch) 96 | 97 | # train for one epoch 98 | train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch) 99 | # evaluate on validation set 100 | err1, err5 = validate(val_loader, s_net, criterion_CE, epoch) 101 | 102 | # remember best prec@1 and save checkpoint 103 | is_best = err1 <= best_err1 104 | best_err1 = min(err1, best_err1) 105 | if is_best: 106 | best_err5 = err5 107 | print ('Current best accuracy (top-1 and 5 error):', best_err1, best_err5) 108 | save_checkpoint({ 109 | 'epoch': epoch, 110 | 'arch': args.net_type, 111 | 'state_dict': s_net.state_dict(), 112 | 'best_err1': best_err1, 113 | 'best_err5': best_err5, 114 | 'optimizer' : optimizer.state_dict(), 115 | }, is_best) 116 | gc.collect() 117 | 118 | print ('Best accuracy (top-1 and 5 error):', best_err1, best_err5) 119 | 120 | 121 | def validate(val_loader, model, criterion_CE, epoch): 122 | batch_time = AverageMeter() 123 | losses = AverageMeter() 124 | top1 = AverageMeter() 125 | top5 = AverageMeter() 126 | 127 | # switch to evaluate mode 128 | model.eval() 129 | 130 | end = time.time() 131 | for i, (input, target) in enumerate(val_loader): 132 | target = target.cuda(non_blocking=True) 133 | 134 | # for PyTorch >= 0.4, volatile=True is replaced by the torch.no_grad() context: 135 | with torch.no_grad(): 136 | input_var = torch.autograd.Variable(input) 137 | target_var = torch.autograd.Variable(target) 138 | output = model(input_var) 139 | loss = criterion_CE(output, target_var) 140 | 141 | # measure accuracy and record loss 142 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 143 | 144 | losses.update(loss.data.item(), input.size(0)) 145 | top1.update(err1.item(), input.size(0)) 146 | top5.update(err5.item(), input.size(0)) 147 | 148 | # measure elapsed time 149 | batch_time.update(time.time() - end) 150 | end = time.time() 151 | 152 | if i % args.print_freq == 0: 153 | print('Test (on val set): [Epoch {0}/{1}][Batch {2}/{3}]\t' 154 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 155 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 156 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 157 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 158 | epoch, args.epochs, i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5)) 159 | 160 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Test Loss {loss.avg:.3f}' 161 | .format(epoch, args.epochs, top1=top1, top5=top5, loss=losses)) 162 | return top1.avg, top5.avg 163 | 
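# NOTE: d_net is wrapped in torch.nn.DataParallel, so the loss_distill returned inside
# train_with_distill below is gathered as a vector with one partial sum per GPU;
# the .sum() there reduces it to a scalar before normalization.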
164 | def train_with_distill(train_loader, d_net, optimizer, criterion_CE, epoch): 165 | d_net.train() 166 | d_net.module.s_net.train() 167 | d_net.module.t_net.train() 168 | 169 | train_loss = AverageMeter() 170 | top1 = AverageMeter() 171 | top5 = AverageMeter() 172 | 173 | for i, (inputs, targets) in enumerate(train_loader): 174 | targets = targets.cuda(non_blocking=True) 175 | batch_size = inputs.shape[0] 176 | outputs, loss_distill = d_net(inputs) 177 | 178 | loss_CE = criterion_CE(outputs, targets) 179 | loss = loss_CE + loss_distill.sum() / batch_size / 10000  # distill term is a sum over features; normalize and down-weight against CE 180 | 181 | err1, err5 = accuracy(outputs.data, targets, topk=(1, 5)) 182 | 183 | train_loss.update(loss.item(), batch_size) 184 | top1.update(err1.item(), batch_size) 185 | top5.update(err5.item(), batch_size) 186 | 187 | optimizer.zero_grad() 188 | loss.backward() 189 | optimizer.step() 190 | 191 | if i % args.print_freq == 0: 192 | print('Train with distillation: [Epoch %d/%d][Batch %d/%d]\t Loss %.3f, Top 1-error %.3f, Top 5-error %.3f' % 193 | (epoch, args.epochs, i, len(train_loader), train_loss.avg, top1.avg, top5.avg)) 194 | 195 | 196 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 197 | directory = "runs/%s/"%(args.net_type) 198 | if not os.path.exists(directory): 199 | os.makedirs(directory) 200 | filename = directory + filename 201 | torch.save(state, filename) 202 | if is_best: 203 | shutil.copyfile(filename, 'runs/%s/'%(args.net_type) + 'model_best.pth.tar') 204 | 205 | 206 | class AverageMeter(object): 207 | """Computes and stores the average and current value""" 208 | def __init__(self): 209 | self.reset() 210 | 211 | def reset(self): 212 | self.val = 0 213 | self.avg = 0 214 | self.sum = 0 215 | self.count = 0 216 | 217 | def update(self, val, n=1): 218 | self.val = val 219 | self.sum += val * n 220 | self.count += n 221 | self.avg = self.sum / self.count 222 | 223 | def adjust_learning_rate(optimizer, epoch): 224 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 225 | lr = args.lr * (0.1 ** (epoch // 30)) 226 | 227 | for param_group in optimizer.param_groups: 228 | param_group['lr'] = lr 229 | 230 | def accuracy(output, target, topk=(1,)): 231 | """Computes the top-k error (in percent) for the specified values of k""" 232 | maxk = max(topk) 233 | batch_size = target.size(0) 234 | 235 | _, pred = output.topk(maxk, 1, True, True) 236 | pred = pred.t() 237 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 238 | 239 | res = [] 240 | for k in topk: 241 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 242 | wrong_k = batch_size - correct_k 243 | res.append(wrong_k.mul_(100.0 / batch_size)) 244 | 245 | return res 246 | 247 | if __name__ == '__main__': 248 | main() 249 | -------------------------------------------------------------------------------- /ImageNet/utils.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py 2 | 3 | import torch 4 | import random 5 | 6 | __all__ = ["Compose", "Lighting", "ColorJitter"] 7 | 8 | 9 | class Compose(object): 10 | """Composes several transforms together. 11 | 12 | Args: 13 | transforms (list of ``Transform`` objects): list of transforms to compose. 
14 | 
15 |     Example:
16 |         >>> transforms.Compose([
17 |         >>>     transforms.CenterCrop(10),
18 |         >>>     transforms.ToTensor(),
19 |         >>> ])
20 |     """
21 | 
22 |     def __init__(self, transforms):
23 |         self.transforms = transforms
24 | 
25 |     def __call__(self, img):
26 |         for t in self.transforms:
27 |             img = t(img)
28 |         return img
29 | 
30 |     def __repr__(self):
31 |         format_string = self.__class__.__name__ + '('
32 |         for t in self.transforms:
33 |             format_string += '\n'
34 |             format_string += '    {0}'.format(t)
35 |         format_string += '\n)'
36 |         return format_string
37 | 
38 | 
39 | class Lighting(object):
40 |     """Lighting noise (AlexNet-style PCA-based noise)"""
41 | 
42 |     def __init__(self, alphastd, eigval, eigvec):
43 |         self.alphastd = alphastd
44 |         self.eigval = torch.Tensor(eigval)
45 |         self.eigvec = torch.Tensor(eigvec)
46 | 
47 |     def __call__(self, img):
48 |         if self.alphastd == 0:
49 |             return img
50 | 
51 |         alpha = img.new().resize_(3).normal_(0, self.alphastd)
52 |         rgb = self.eigvec.type_as(img).clone() \
53 |             .mul(alpha.view(1, 3).expand(3, 3)) \
54 |             .mul(self.eigval.view(1, 3).expand(3, 3)) \
55 |             .sum(1).squeeze()
56 | 
57 |         return img.add(rgb.view(3, 1, 1).expand_as(img))
58 | 
59 | 
60 | class Grayscale(object):
61 | 
62 |     def __call__(self, img):
63 |         gs = img.clone()
64 |         gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)  # ITU-R 601 luma; add_(tensor, alpha=...) replaces the deprecated add_(alpha, tensor)
65 |         gs[1].copy_(gs[0])
66 |         gs[2].copy_(gs[0])
67 |         return gs
68 | 
69 | 
70 | class Saturation(object):
71 | 
72 |     def __init__(self, var):
73 |         self.var = var
74 | 
75 |     def __call__(self, img):
76 |         gs = Grayscale()(img)
77 |         alpha = random.uniform(-self.var, self.var)
78 |         return img.lerp(gs, alpha)
79 | 
80 | 
81 | class Brightness(object):
82 | 
83 |     def __init__(self, var):
84 |         self.var = var
85 | 
86 |     def __call__(self, img):
87 |         gs = img.new().resize_as_(img).zero_()
88 |         alpha = random.uniform(-self.var, self.var)
89 |         return img.lerp(gs, alpha)
90 | 
91 | 
92 | class Contrast(object):
93 | 
94 |     def __init__(self, var):
95 |         self.var = var
96 | 
97 |     def __call__(self, img):
98 |         gs = Grayscale()(img)
99 |         gs.fill_(gs.mean())
100 |         alpha = random.uniform(-self.var, self.var)
101 |         return img.lerp(gs, alpha)
102 | 
103 | 
104 | class ColorJitter(object):
105 | 
106 |     def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
107 |         self.brightness = brightness
108 |         self.contrast = contrast
109 |         self.saturation = saturation
110 | 
111 |     def __call__(self, img):
112 |         self.transforms = []
113 |         if self.brightness != 0:
114 |             self.transforms.append(Brightness(self.brightness))
115 |         if self.contrast != 0:
116 |             self.transforms.append(Contrast(self.contrast))
117 |         if self.saturation != 0:
118 |             self.transforms.append(Saturation(self.saturation))
119 | 
120 |         random.shuffle(self.transforms)  # apply the jitter transforms in random order
121 |         transform = Compose(self.transforms)
122 |         # print(transform)
123 |         return transform(img)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019-present NAVER Corp.
2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Official PyTorch implementation of "A Comprehensive Overhaul of Feature Distillation" 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | This project contains subcomponents with separate copyright notices and license terms. 5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 6 | 7 | ===== 8 | 9 | pytorch/vision 10 | https://github.com/pytorch/vision 11 | 12 | 13 | BSD 3-Clause License 14 | 15 | Copyright (c) Soumith Chintala 2016, 16 | All rights reserved. 17 | 18 | Redistribution and use in source and binary forms, with or without 19 | modification, are permitted provided that the following conditions are met: 20 | 21 | * Redistributions of source code must retain the above copyright notice, this 22 | list of conditions and the following disclaimer. 23 | 24 | * Redistributions in binary form must reproduce the above copyright notice, 25 | this list of conditions and the following disclaimer in the documentation 26 | and/or other materials provided with the distribution. 27 | 28 | * Neither the name of the copyright holder nor the names of its 29 | contributors may be used to endorse or promote products derived from 30 | this software without specific prior written permission. 31 | 32 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 33 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 34 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 35 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 36 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 37 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 38 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 39 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 40 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 41 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
42 | 43 | ===== 44 | 45 | xternalz/WideResNet-pytorch 46 | https://github.com/xternalz/WideResNet-pytorch 47 | 48 | 49 | MIT License 50 | 51 | Copyright (c) 2019 xternalz 52 | 53 | Permission is hereby granted, free of charge, to any person obtaining a copy 54 | of this software and associated documentation files (the "Software"), to deal 55 | in the Software without restriction, including without limitation the rights 56 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 57 | copies of the Software, and to permit persons to whom the Software is 58 | furnished to do so, subject to the following conditions: 59 | 60 | The above copyright notice and this permission notice shall be included in all 61 | copies or substantial portions of the Software. 62 | 63 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 64 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 65 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 66 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 67 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 68 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 69 | SOFTWARE. 70 | 71 | ===== 72 | 73 | eladhoffer/convNet.pytorch 74 | https://github.com/eladhoffer/convNet.pytorch 75 | 76 | 77 | 78 | MIT License 79 | 80 | Copyright (c) 2017 Elad Hoffer 81 | 82 | Permission is hereby granted, free of charge, to any person obtaining a copy 83 | of this software and associated documentation files (the "Software"), to deal 84 | in the Software without restriction, including without limitation the rights 85 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 86 | copies of the Software, and to permit persons to whom the Software is 87 | furnished to do so, subject to the following conditions: 88 | 89 | The above copyright notice and this permission notice shall be included in all 90 | copies or substantial portions of the Software. 91 | 92 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 93 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 94 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 95 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 96 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 97 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 98 | SOFTWARE. 99 | 100 | ===== 101 | 102 | pytorch/examples 103 | https://github.com/pytorch/examples/blob/master/LICENSE 104 | 105 | 106 | BSD 3-Clause License 107 | 108 | Copyright (c) 2017, 109 | All rights reserved. 110 | 111 | Redistribution and use in source and binary forms, with or without 112 | modification, are permitted provided that the following conditions are met: 113 | 114 | * Redistributions of source code must retain the above copyright notice, this 115 | list of conditions and the following disclaimer. 116 | 117 | * Redistributions in binary form must reproduce the above copyright notice, 118 | this list of conditions and the following disclaimer in the documentation 119 | and/or other materials provided with the distribution. 120 | 121 | * Neither the name of the copyright holder nor the names of its 122 | contributors may be used to endorse or promote products derived from 123 | this software without specific prior written permission. 
124 | 125 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 126 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 127 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 128 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 129 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 130 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 131 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 132 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 133 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 134 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 135 | 136 | ===== 137 | 138 | jfzhang95/pytorch-deeplab-xception 139 | https://github.com/jfzhang95/pytorch-deeplab-xception 140 | 141 | 142 | 143 | MIT License 144 | 145 | Copyright (c) 2018 Pyjcsx 146 | 147 | Permission is hereby granted, free of charge, to any person obtaining a copy 148 | of this software and associated documentation files (the "Software"), to deal 149 | in the Software without restriction, including without limitation the rights 150 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 151 | copies of the Software, and to permit persons to whom the Software is 152 | furnished to do so, subject to the following conditions: 153 | 154 | The above copyright notice and this permission notice shall be included in all 155 | copies or substantial portions of the Software. 156 | 157 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 158 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 159 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 160 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 161 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 162 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 163 | SOFTWARE. 164 | 165 | ===== -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A Comprehensive Overhaul of Feature Distillation 2 | **Accepted at ICCV 2019** 3 | 4 | Official PyTorch implementation of "A Comprehensive Overhaul of Feature Distillation" | [paper](https://arxiv.org/abs/1904.01866) | [project page](https://sites.google.com/view/byeongho-heo/overhaul) | [blog](https://clova-ai.blog/2019/08/22/a-comprehensive-overhaul-of-feature-distillation-iccv-2019) 5 | 6 | Byeongho Heo, Jeesoo Kim, Sangdoo Yun, Hyojin Park, Nojun Kwak, Jin Young Choi 7 | 8 | Clova AI Research, NAVER Corp. \ 9 | Seoul National University 10 | 11 | ## Requirements 12 | - Python3 13 | - PyTorch (> 0.4.1) 14 | - torchvision 15 | - numpy 16 | - scipy 17 | 18 | ## Updates 19 | ***19 Nov 2019*** Segmentation released 20 | 21 | ***10 Sep 2019*** Initial upload 22 | 23 | ## CIFAR-100 24 | 25 | ### Settings 26 | We provide the code of the experimental settings specified in the paper. 
27 | 
28 | | Setup | Compression type | Teacher | Student | Teacher size | Student size | Size ratio |
29 | |:-----:|:----------------:|:-----------:|:-----------:|:------------:|:------------:|:----------:|
30 | | (a) | Depth | WRN 28-4 | WRN 16-4 | 5.87M | 2.77M | 47.2% |
31 | | (b) | Channel | WRN 28-4 | WRN 28-2 | 5.87M | 1.47M | 25.0% |
32 | | (c) | Depth & channel | WRN 28-4 | WRN 16-2 | 5.87M | 0.70M | 11.9% |
33 | | (d) | Architecture | WRN 28-4 | ResNet 56 | 5.87M | 0.86M | 14.7% |
34 | | (e) | Architecture | Pyramid-200 | WRN 28-4 | 26.84M | 5.87M | 21.9% |
35 | | (f) | Architecture | Pyramid-200 | Pyramid-110 | 26.84M | 3.91M | 14.6% |
36 | 
37 | ### Teacher models
38 | Download the following pre-trained teacher networks and put them into the ```./data``` directory
39 | - [Wide Residual Network 28-4](https://drive.google.com/open?id=1Quxgs5teXVXwD3jBdkk-WeNLNpxbiZXN)
40 | - [PyramidNet-200(240)](https://drive.google.com/open?id=1_QgG81fNM3OvVIbMAxDPykKWuSIyKnmz)
41 | 
42 | ### Training
43 | Run ```CIFAR-100/train_with_distillation.py``` with a setting letter (a-f)
44 | ```
45 | cd CIFAR-100
46 | python train_with_distillation.py \
47 |     --setting a \
48 |     --epochs 200 \
49 |     --batch_size 128 \
50 |     --lr 0.1 \
51 |     --momentum 0.9 \
52 |     --weight_decay 5e-4
53 | ```
54 | 
55 | For the PyramidNet teacher (settings e and f), we used batch size 64 to save GPU memory.
56 | ```
57 | cd CIFAR-100
58 | python train_with_distillation.py \
59 |     --setting e \
60 |     --epochs 200 \
61 |     --batch_size 64 \
62 |     --lr 0.1 \
63 |     --momentum 0.9 \
64 |     --weight_decay 5e-4
65 | ```
66 | 
67 | ### Experimental results
68 | 
69 | The performance measure is classification error rate (%).
70 | 
71 | 
72 | | Setup | Teacher | Student | Original | Proposed | Improvement |
73 | |:-----:|:-----------:|:-----------:|:--------:|:--------:|:-----------:|
74 | | (a) | WRN 28-4 | WRN 16-4 | 22.72% | 20.89% | 1.83% |
75 | | (b) | WRN 28-4 | WRN 28-2 | 24.88% | 21.98% | 2.90% |
76 | | (c) | WRN 28-4 | WRN 16-2 | 27.32% | 24.08% | 3.24% |
77 | | (d) | WRN 28-4 | ResNet 56 | 27.68% | 24.44% | 3.24% |
78 | | (e) | Pyramid-200 | WRN 28-4 | 21.09% | 17.80% | 3.29% |
79 | | (f) | Pyramid-200 | Pyramid-110 | 22.58% | 18.89% | 3.69% |
80 | 
81 | ## ImageNet
82 | 
83 | ### Settings
84 | 
85 | | Setup | Compression type | Teacher | Student | Teacher size | Student size | Size ratio |
86 | |:-----:|:----------------:|:----------:|:---------:|:------------:|:------------:|:----------:|
87 | | (a) | Depth | ResNet 152 | ResNet 50 | 60.19M | 25.56M | 42.47% |
88 | | (b) | Architecture | ResNet 50 | MobileNet | 25.56M | 4.23M | 16.55% |
89 | 
90 | 
91 | For ImageNet, the teacher model is downloaded automatically from the PyTorch model zoo.
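Under the hood this relies on torchvision's pretrained-weights mechanism. As an illustrative sketch only (the actual construction is wrapped in ```ImageNet/models/ResNet.py``` and may differ in detail):

```python
# Illustrative sketch, not the repository's exact code: torchvision downloads
# and caches the pretrained ImageNet weights on first use.
import torchvision.models as models

t_net = models.resnet152(pretrained=True)   # teacher: weights fetched automatically
s_net = models.resnet50(pretrained=False)   # student: randomly initialized
```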
92 | 
93 | ### Training
94 | 
95 | - (a): ResNet152 to ResNet50
96 | ```shell script
97 | cd ImageNet
98 | python train_with_distillation.py \
99 |     --data_path your/path/to/ImageNet \
100 |     --net_type resnet \
101 |     --epochs 100 \
102 |     --lr 0.1 \
103 |     --batch_size 256
104 | ```
105 | 
106 | - (b): ResNet50 to MobileNet
107 | ```shell script
108 | cd ImageNet
109 | python train_with_distillation.py \
110 |     --data_path your/path/to/ImageNet \
111 |     --net_type mobilenet \
112 |     --epochs 100 \
113 |     --lr 0.1 \
114 |     --batch_size 256
115 | ```
116 | 
117 | ### Experimental results
118 | 
119 | - ResNet 50
120 | 
121 | | Network | Method | Top1-error | Top5-error |
122 | |:----------:|:--------:|:----------:|:----------:|
123 | | ResNet 152 | Teacher | 21.69 | 5.95 |
124 | | ResNet 50 | Original | 23.72 | 6.97 |
125 | | ResNet 50 | Proposed | __21.65__ | __5.83__ |
126 | 
127 | - MobileNet
128 | 
129 | | Network | Method | Top1-error | Top5-error |
130 | |:---------:|:--------:|:-----:|:-----:|
131 | | ResNet 50 | Teacher | 23.84 | 7.14 |
132 | | MobileNet | Original | 31.13 | 11.24 |
133 | | MobileNet | Proposed | __28.75__ | __9.66__ |
134 | 
135 | ## Segmentation - Pascal VOC
136 | 
137 | Our segmentation code is based on [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception).
138 | 
139 | ### Additional requirements
140 | 
141 | - tqdm
142 | - matplotlib
143 | - pillow
144 | 
145 | ### Settings
146 | 
147 | | Teacher | Student | Teacher size | Student size | Size ratio |
148 | |:----------:|:---------:|:------------:|:------------:|:----------:|
149 | | ResNet 101 | ResNet 18 | 59.3M | 16.6M | 28.0% |
150 | | ResNet 101 | MobileNetV2 | 59.3M | 5.8M | 9.8% |
151 | 
152 | 
153 | ### Teacher models
154 | Download the following pre-trained teacher network and put it into the ```./Segmentation/pretrained``` directory
155 | - [ResNet101-DeepLabV3+](https://drive.google.com/open?id=1Pz2OT5KoSNvU5rc3w5d2R8_0OBkKSkLR)
156 | 
157 | We used the pre-trained model from [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception) as the teacher network.
158 | 
159 | ### Training
160 | 
161 | - First, move to the segmentation folder: ```cd Segmentation```
162 | - Next, configure your dataset path in ```Segmentation/mypath.py```
163 | 
164 | - Without distillation
165 |     - ResNet 18
166 |     ```shell script
167 |     CUDA_VISIBLE_DEVICES=0,1 python train.py --backbone resnet18 --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
168 |     ```
169 | 
170 |     - MobileNetV2
171 |     ```shell script
172 |     CUDA_VISIBLE_DEVICES=0,1 python train.py --backbone mobilenet --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
173 |     ```
174 | 
175 | - Distillation
176 |     - ResNet 18
177 |     ```shell script
178 |     CUDA_VISIBLE_DEVICES=0,1 python train_with_distillation.py --backbone resnet18 --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
179 |     ```
180 | 
181 |     - MobileNetV2
182 |     ```shell script
183 |     CUDA_VISIBLE_DEVICES=0,1 python train_with_distillation.py --backbone mobilenet --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
184 |     ```
185 | 
186 | ### Experimental results
187 | 
188 | These numbers are the validation performance of this code.
189 | 
190 | - ResNet 18
191 | 
192 | | Network | Method | mIOU |
193 | |:----------:|:--------:|:----------:|
194 | | ResNet 101 | Teacher | 77.89 |
195 | | ResNet 18 | Original | 72.07 |
196 | | ResNet 18 | Proposed | __73.98__ |
197 | 
198 | - MobileNetV2
199 | 
200 | | Network | Method | mIOU |
201 | |:---------:|:--------:|:-----:|
202 | | ResNet 101 | Teacher | 77.89 |
203 | | MobileNetV2 | Original | 68.46 |
204 | | MobileNetV2 | Proposed | __71.19__ |
205 | 
206 | 
207 | In the paper, we reported performance on the **test** set, but this code measures performance on the **val** set.
208 | Therefore, the numbers produced by this code do not match those in the paper.
209 | For an accurate measurement, evaluate on the **test** set with the [Pascal VOC evaluation server](http://host.robots.ox.ac.uk/pascal/VOC/).
210 | 
211 | 
212 | ## Citation
213 | 
214 | ```
215 | @inproceedings{heo2019overhaul,
216 |   title={A Comprehensive Overhaul of Feature Distillation},
217 |   author={Heo, Byeongho and Kim, Jeesoo and Yun, Sangdoo and Park, Hyojin and Kwak, Nojun and Choi, Jin Young},
218 |   booktitle = {International Conference on Computer Vision (ICCV)},
219 |   year={2019}
220 | }
221 | ```
222 | 
223 | ## License
224 | 
225 | ```
226 | Copyright (c) 2019-present NAVER Corp.
227 | 
228 | Permission is hereby granted, free of charge, to any person obtaining a copy
229 | of this software and associated documentation files (the "Software"), to deal
230 | in the Software without restriction, including without limitation the rights
231 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
232 | copies of the Software, and to permit persons to whom the Software is
233 | furnished to do so, subject to the following conditions:
234 | 
235 | The above copyright notice and this permission notice shall be included in
236 | all copies or substantial portions of the Software.
237 | 
238 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
239 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
240 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
241 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
242 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
243 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
244 | THE SOFTWARE.
245 | ```
246 | 
--------------------------------------------------------------------------------
/Segmentation/README.md:
--------------------------------------------------------------------------------
1 | ## Segmentation - Pascal VOC
2 | 
3 | Our segmentation code is based on [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception).
4 | 
5 | ### Additional requirements
6 | 
7 | - tqdm
8 | - matplotlib
9 | - pillow
10 | 
11 | ### Settings
12 | 
13 | | Teacher | Student | Teacher size | Student size | Size ratio |
14 | |:----------:|:---------:|:------------:|:------------:|:----------:|
15 | | ResNet 101 | ResNet 18 | 59.3M | 16.6M | 28.0% |
16 | | ResNet 101 | MobileNetV2 | 59.3M | 5.8M | 9.8% |
17 | 
18 | 
19 | ### Teacher models
20 | Download the following pre-trained teacher network and put it into the ```./Segmentation/pretrained``` directory
21 | - [ResNet101-DeepLabV3+](https://drive.google.com/open?id=1Pz2OT5KoSNvU5rc3w5d2R8_0OBkKSkLR)
22 | 
23 | We used the pre-trained model from [pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception) as the teacher network.
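Loading the downloaded checkpoint into the teacher typically looks like the following. This is a hypothetical sketch: the constructor arguments and the checkpoint filename/keys follow pytorch-deeplab-xception conventions and are assumptions, not the exact code of ```train_with_distillation.py```:

```python
# Hypothetical sketch -- the filename, the checkpoint keys, and the DeepLab
# constructor arguments are assumptions based on pytorch-deeplab-xception.
import torch
from modeling.deeplab import DeepLab

t_net = DeepLab(num_classes=21, backbone='resnet', output_stride=16)
checkpoint = torch.load('pretrained/deeplab-resnet.pth.tar', map_location='cpu')
t_net.load_state_dict(checkpoint['state_dict'])
t_net.eval()  # eval mode for standalone evaluation of the teacher
```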
24 | 
25 | ### Training
26 | 
27 | - First, move to the segmentation folder: ```cd Segmentation```
28 | - Next, configure your dataset path in ```Segmentation/mypath.py``` (a sketch is given at the end of this README)
29 | 
30 | - Without distillation
31 |     - ResNet 18
32 |     ```shell script
33 |     CUDA_VISIBLE_DEVICES=0,1 python train.py --backbone resnet18 --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
34 |     ```
35 | 
36 |     - MobileNetV2
37 |     ```shell script
38 |     CUDA_VISIBLE_DEVICES=0,1 python train.py --backbone mobilenet --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
39 |     ```
40 | 
41 | - Distillation
42 |     - ResNet 18
43 |     ```shell script
44 |     CUDA_VISIBLE_DEVICES=0,1 python train_with_distillation.py --backbone resnet18 --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
45 |     ```
46 | 
47 |     - MobileNetV2
48 |     ```shell script
49 |     CUDA_VISIBLE_DEVICES=0,1 python train_with_distillation.py --backbone mobilenet --gpu-ids 0,1 --dataset pascal --use-sbd --nesterov
50 |     ```
51 | 
52 | ### Experimental results
53 | 
54 | These numbers are the validation performance of this code.
55 | 
56 | - ResNet 18
57 | 
58 | | Network | Method | mIOU |
59 | |:----------:|:--------:|:----------:|
60 | | ResNet 101 | Teacher | 77.89 |
61 | | ResNet 18 | Original | 72.07 |
62 | | ResNet 18 | Proposed | __73.98__ |
63 | 
64 | - MobileNetV2
65 | 
66 | | Network | Method | mIOU |
67 | |:---------:|:--------:|:-----:|
68 | | ResNet 101 | Teacher | 77.89 |
69 | | MobileNetV2 | Original | 68.46 |
70 | | MobileNetV2 | Proposed | __71.19__ |
71 | 
72 | 
73 | In the paper, we reported performance on the **test** set, but this code measures performance on the **val** set.
74 | Therefore, the numbers produced by this code do not match those in the paper.
75 | For an accurate measurement, evaluate on the **test** set with the [Pascal VOC evaluation server](http://host.robots.ox.ac.uk/pascal/VOC/).
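As a reference for the ```mypath.py``` step above: all dataloaders resolve dataset roots through a single ```Path.db_root_dir()``` helper. A minimal sketch, in which every concrete path is a placeholder to edit for your machine (the expected sub-directories are taken from the dataset classes in ```dataloaders/datasets/```):

```python
# Minimal sketch of Segmentation/mypath.py -- all paths below are placeholders.
class Path(object):
    @staticmethod
    def db_root_dir(dataset):
        if dataset == 'pascal':
            return '/path/to/VOCdevkit/VOC2012/'   # JPEGImages, SegmentationClass, ImageSets
        elif dataset == 'sbd':
            return '/path/to/benchmark_RELEASE/'   # dataset/img, dataset/cls
        elif dataset == 'cityscapes':
            return '/path/to/cityscapes/'          # leftImg8bit, gtFine_trainvaltest
        elif dataset == 'coco':
            return '/path/to/coco/'                # annotations, images
        else:
            raise NotImplementedError('Unknown dataset: {}'.format(dataset))
```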
76 | -------------------------------------------------------------------------------- /Segmentation/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd 2 | from torch.utils.data import DataLoader 3 | 4 | def make_data_loader(args, **kwargs): 5 | 6 | if args.dataset == 'pascal': 7 | train_set = pascal.VOCSegmentation(args, split='train') 8 | val_set = pascal.VOCSegmentation(args, split='val') 9 | if args.use_sbd: 10 | sbd_train = sbd.SBDSegmentation(args, split=['train', 'val']) 11 | train_set = combine_dbs.CombineDBs([train_set, sbd_train], excluded=[val_set]) 12 | 13 | num_class = train_set.NUM_CLASSES 14 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 15 | val_loader = DataLoader(val_set, batch_size=1, shuffle=False, **kwargs) 16 | test_loader = None 17 | 18 | return train_loader, val_loader, test_loader, num_class 19 | 20 | elif args.dataset == 'cityscapes': 21 | train_set = cityscapes.CityscapesSegmentation(args, split='train') 22 | val_set = cityscapes.CityscapesSegmentation(args, split='val') 23 | test_set = cityscapes.CityscapesSegmentation(args, split='test') 24 | num_class = train_set.NUM_CLASSES 25 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 26 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 27 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, **kwargs) 28 | 29 | return train_loader, val_loader, test_loader, num_class 30 | 31 | elif args.dataset == 'coco': 32 | train_set = coco.COCOSegmentation(args, split='train') 33 | val_set = coco.COCOSegmentation(args, split='val') 34 | num_class = train_set.NUM_CLASSES 35 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 36 | val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs) 37 | test_loader = None 38 | return train_loader, val_loader, test_loader, num_class 39 | 40 | else: 41 | raise NotImplementedError 42 | 43 | -------------------------------------------------------------------------------- /Segmentation/dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | from torchvision import transforms 5 | 6 | from PIL import Image, ImageOps, ImageFilter 7 | 8 | class Normalize(object): 9 | """Normalize a tensor image with mean and standard deviation. 10 | Args: 11 | mean (tuple): means for each channel. 12 | std (tuple): standard deviations for each channel. 
13 | """ 14 | def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, sample): 19 | img = sample['image'] 20 | mask = sample['label'] 21 | img = np.array(img).astype(np.float32) 22 | mask = np.array(mask).astype(np.float32) 23 | img /= 255.0 24 | img -= self.mean 25 | img /= self.std 26 | 27 | return {'image': img, 28 | 'label': mask} 29 | 30 | 31 | class ToTensor(object): 32 | """Convert ndarrays in sample to Tensors.""" 33 | 34 | def __call__(self, sample): 35 | # swap color axis because 36 | # numpy image: H x W x C 37 | # torch image: C X H X W 38 | img = sample['image'] 39 | mask = sample['label'] 40 | img = np.array(img).astype(np.float32).transpose((2, 0, 1)) 41 | mask = np.array(mask).astype(np.float32) 42 | 43 | img = torch.from_numpy(img).float() 44 | mask = torch.from_numpy(mask).float() 45 | 46 | return {'image': img, 47 | 'label': mask} 48 | 49 | 50 | class RandomHorizontalFlip(object): 51 | def __call__(self, sample): 52 | img = sample['image'] 53 | mask = sample['label'] 54 | if random.random() < 0.5: 55 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 56 | mask = mask.transpose(Image.FLIP_LEFT_RIGHT) 57 | 58 | return {'image': img, 59 | 'label': mask} 60 | 61 | 62 | class RandomRotate(object): 63 | def __init__(self, degree): 64 | self.degree = degree 65 | 66 | def __call__(self, sample): 67 | img = sample['image'] 68 | mask = sample['label'] 69 | rotate_degree = random.uniform(-1*self.degree, self.degree) 70 | img = img.rotate(rotate_degree, Image.BILINEAR) 71 | mask = mask.rotate(rotate_degree, Image.NEAREST) 72 | 73 | return {'image': img, 74 | 'label': mask} 75 | 76 | 77 | class RandomGaussianBlur(object): 78 | def __call__(self, sample): 79 | img = sample['image'] 80 | mask = sample['label'] 81 | if random.random() < 0.5: 82 | img = img.filter(ImageFilter.GaussianBlur( 83 | radius=random.random())) 84 | 85 | return {'image': img, 86 | 'label': mask} 87 | 88 | 89 | class RandomScaleCrop(object): 90 | def __init__(self, base_size, crop_size, fill=0): 91 | self.base_size = base_size 92 | self.crop_size = crop_size 93 | self.fill = fill 94 | 95 | def __call__(self, sample): 96 | img = sample['image'] 97 | mask = sample['label'] 98 | # random scale (short edge) 99 | short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) 100 | w, h = img.size 101 | if h > w: 102 | ow = short_size 103 | oh = int(1.0 * h * ow / w) 104 | else: 105 | oh = short_size 106 | ow = int(1.0 * w * oh / h) 107 | img = img.resize((ow, oh), Image.BILINEAR) 108 | mask = mask.resize((ow, oh), Image.NEAREST) 109 | # pad crop 110 | if short_size < self.crop_size: 111 | padh = self.crop_size - oh if oh < self.crop_size else 0 112 | padw = self.crop_size - ow if ow < self.crop_size else 0 113 | img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) 114 | mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) 115 | # random crop crop_size 116 | w, h = img.size 117 | x1 = random.randint(0, w - self.crop_size) 118 | y1 = random.randint(0, h - self.crop_size) 119 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 120 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 121 | 122 | return {'image': img, 123 | 'label': mask} 124 | 125 | 126 | class FixScaleCrop(object): 127 | def __init__(self, crop_size): 128 | self.crop_size = crop_size 129 | 130 | def __call__(self, sample): 131 | img = sample['image'] 132 | mask = sample['label'] 133 | w, h = img.size 134 | if 
w > h: 135 | oh = self.crop_size 136 | ow = int(1.0 * w * oh / h) 137 | else: 138 | ow = self.crop_size 139 | oh = int(1.0 * h * ow / w) 140 | img = img.resize((ow, oh), Image.BILINEAR) 141 | mask = mask.resize((ow, oh), Image.NEAREST) 142 | # center crop 143 | w, h = img.size 144 | x1 = int(round((w - self.crop_size) / 2.)) 145 | y1 = int(round((h - self.crop_size) / 2.)) 146 | img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 147 | mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) 148 | 149 | return {'image': img, 150 | 'label': mask} 151 | 152 | class FixedResize(object): 153 | def __init__(self, size): 154 | self.image_resize = transforms.Resize(size, Image.BILINEAR) 155 | self.mask_resize = transforms.Resize(size, Image.NEAREST) 156 | 157 | def __call__(self, sample): 158 | img = sample['image'] 159 | mask = sample['label'] 160 | 161 | assert img.size == mask.size 162 | 163 | img = self.image_resize(img) 164 | mask = self.mask_resize(mask) 165 | 166 | return {'image': img, 167 | 'label': mask} -------------------------------------------------------------------------------- /Segmentation/dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/overhaul-distillation/76344a84a7ce23c894f41a2e05b866c9b73fd85a/Segmentation/dataloaders/datasets/__init__.py -------------------------------------------------------------------------------- /Segmentation/dataloaders/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy.misc as m 4 | from PIL import Image 5 | from torch.utils import data 6 | from mypath import Path 7 | from torchvision import transforms 8 | from dataloaders import custom_transforms as tr 9 | 10 | class CityscapesSegmentation(data.Dataset): 11 | NUM_CLASSES = 19 12 | 13 | def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train"): 14 | 15 | self.root = root 16 | self.split = split 17 | self.args = args 18 | self.files = {} 19 | 20 | self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) 21 | self.annotations_base = os.path.join(self.root, 'gtFine_trainvaltest', 'gtFine', self.split) 22 | 23 | self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') 24 | 25 | self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1] 26 | self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33] 27 | self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence', \ 28 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', \ 29 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', \ 30 | 'motorcycle', 'bicycle'] 31 | 32 | self.ignore_index = 255 33 | self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES))) 34 | 35 | if not self.files[split]: 36 | raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) 37 | 38 | print("Found %d %s images" % (len(self.files[split]), split)) 39 | 40 | def __len__(self): 41 | return len(self.files[self.split]) 42 | 43 | def __getitem__(self, index): 44 | 45 | img_path = self.files[self.split][index].rstrip() 46 | lbl_path = os.path.join(self.annotations_base, 47 | img_path.split(os.sep)[-2], 48 | os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png') 49 | 50 | _img = Image.open(img_path).convert('RGB') 51 | _tmp = 
np.array(Image.open(lbl_path), dtype=np.uint8) 52 | _tmp = self.encode_segmap(_tmp) 53 | _target = Image.fromarray(_tmp) 54 | 55 | sample = {'image': _img, 'label': _target} 56 | 57 | if self.split == 'train': 58 | return self.transform_tr(sample) 59 | elif self.split == 'val': 60 | return self.transform_val(sample) 61 | elif self.split == 'test': 62 | return self.transform_ts(sample) 63 | 64 | def encode_segmap(self, mask): 65 | # Put all void classes to zero 66 | for _voidc in self.void_classes: 67 | mask[mask == _voidc] = self.ignore_index 68 | for _validc in self.valid_classes: 69 | mask[mask == _validc] = self.class_map[_validc] 70 | return mask 71 | 72 | def recursive_glob(self, rootdir='.', suffix=''): 73 | """Performs recursive glob with given suffix and rootdir 74 | :param rootdir is the root directory 75 | :param suffix is the suffix to be searched 76 | """ 77 | return [os.path.join(looproot, filename) 78 | for looproot, _, filenames in os.walk(rootdir) 79 | for filename in filenames if filename.endswith(suffix)] 80 | 81 | def transform_tr(self, sample): 82 | composed_transforms = transforms.Compose([ 83 | tr.RandomHorizontalFlip(), 84 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), 85 | tr.RandomGaussianBlur(), 86 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 87 | tr.ToTensor()]) 88 | 89 | return composed_transforms(sample) 90 | 91 | def transform_val(self, sample): 92 | 93 | composed_transforms = transforms.Compose([ 94 | tr.FixedResize(size=self.args.crop_size), 95 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 96 | tr.ToTensor()]) 97 | 98 | return composed_transforms(sample) 99 | 100 | def transform_ts(self, sample): 101 | 102 | composed_transforms = transforms.Compose([ 103 | tr.FixedResize(size=self.args.crop_size), 104 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 105 | tr.ToTensor()]) 106 | 107 | return composed_transforms(sample) 108 | 109 | if __name__ == '__main__': 110 | from dataloaders.utils import decode_segmap 111 | from torch.utils.data import DataLoader 112 | import matplotlib.pyplot as plt 113 | import argparse 114 | 115 | parser = argparse.ArgumentParser() 116 | args = parser.parse_args() 117 | args.base_size = 513 118 | args.crop_size = 513 119 | 120 | cityscapes_train = CityscapesSegmentation(args, split='train') 121 | 122 | dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) 123 | 124 | for ii, sample in enumerate(dataloader): 125 | for jj in range(sample["image"].size()[0]): 126 | img = sample['image'].numpy() 127 | gt = sample['label'].numpy() 128 | tmp = np.array(gt[jj]).astype(np.uint8) 129 | segmap = decode_segmap(tmp, dataset='cityscapes') 130 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 131 | img_tmp *= (0.229, 0.224, 0.225) 132 | img_tmp += (0.485, 0.456, 0.406) 133 | img_tmp *= 255.0 134 | img_tmp = img_tmp.astype(np.uint8) 135 | plt.figure() 136 | plt.title('display') 137 | plt.subplot(211) 138 | plt.imshow(img_tmp) 139 | plt.subplot(212) 140 | plt.imshow(segmap) 141 | 142 | if ii == 1: 143 | break 144 | 145 | plt.show(block=True) 146 | 147 | -------------------------------------------------------------------------------- /Segmentation/dataloaders/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from mypath import Path 5 | from tqdm import trange 6 | import os 
7 | from pycocotools.coco import COCO
8 | from pycocotools import mask
9 | from torchvision import transforms
10 | from dataloaders import custom_transforms as tr
11 | from PIL import Image, ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | 
14 | 
15 | class COCOSegmentation(Dataset):
16 |     NUM_CLASSES = 21
17 |     CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4,
18 |                 1, 64, 20, 63, 7, 72]
19 | 
20 |     def __init__(self,
21 |                  args,
22 |                  base_dir=Path.db_root_dir('coco'),
23 |                  split='train',
24 |                  year='2017'):
25 |         super().__init__()
26 |         ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
27 |         ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
28 |         self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
29 |         self.split = split
30 |         self.coco = COCO(ann_file)
31 |         self.coco_mask = mask
32 |         if os.path.exists(ids_file):
33 |             self.ids = torch.load(ids_file)
34 |         else:
35 |             ids = list(self.coco.imgs.keys())
36 |             self.ids = self._preprocess(ids, ids_file)
37 |         self.args = args
38 | 
39 |     def __getitem__(self, index):
40 |         _img, _target = self._make_img_gt_point_pair(index)
41 |         sample = {'image': _img, 'label': _target}
42 | 
43 |         if self.split == "train":
44 |             return self.transform_tr(sample)
45 |         elif self.split == 'val':
46 |             return self.transform_val(sample)
47 | 
48 |     def _make_img_gt_point_pair(self, index):
49 |         coco = self.coco
50 |         img_id = self.ids[index]
51 |         img_metadata = coco.loadImgs(img_id)[0]
52 |         path = img_metadata['file_name']
53 |         _img = Image.open(os.path.join(self.img_dir, path)).convert('RGB')
54 |         cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
55 |         _target = Image.fromarray(self._gen_seg_mask(
56 |             cocotarget, img_metadata['height'], img_metadata['width']))
57 | 
58 |         return _img, _target
59 | 
60 |     def _preprocess(self, ids, ids_file):
61 |         print("Preprocessing masks, this will take a while. " + \
62 |               "But don't worry, it only runs once for each split.")
63 |         tbar = trange(len(ids))
64 |         new_ids = []
65 |         for i in tbar:
66 |             img_id = ids[i]
67 |             cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
68 |             img_metadata = self.coco.loadImgs(img_id)[0]
69 |             mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
70 |                                       img_metadata['width'])
71 |             # keep only images whose mask has more than 1k foreground pixels
72 |             if (mask > 0).sum() > 1000:
73 |                 new_ids.append(img_id)
74 |             tbar.set_description('Doing: {}/{}, got {} qualified images'.
\ 75 | format(i, len(ids), len(new_ids))) 76 | print('Found number of qualified images: ', len(new_ids)) 77 | torch.save(new_ids, ids_file) 78 | return new_ids 79 | 80 | def _gen_seg_mask(self, target, h, w): 81 | mask = np.zeros((h, w), dtype=np.uint8) 82 | coco_mask = self.coco_mask 83 | for instance in target: 84 | rle = coco_mask.frPyObjects(instance['segmentation'], h, w) 85 | m = coco_mask.decode(rle) 86 | cat = instance['category_id'] 87 | if cat in self.CAT_LIST: 88 | c = self.CAT_LIST.index(cat) 89 | else: 90 | continue 91 | if len(m.shape) < 3: 92 | mask[:, :] += (mask == 0) * (m * c) 93 | else: 94 | mask[:, :] += (mask == 0) * (((np.sum(m, axis=2)) > 0) * c).astype(np.uint8) 95 | return mask 96 | 97 | def transform_tr(self, sample): 98 | composed_transforms = transforms.Compose([ 99 | tr.RandomHorizontalFlip(), 100 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 101 | tr.RandomGaussianBlur(), 102 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 103 | tr.ToTensor()]) 104 | 105 | return composed_transforms(sample) 106 | 107 | def transform_val(self, sample): 108 | 109 | composed_transforms = transforms.Compose([ 110 | tr.FixedResize(size=self.args.crop_size), 111 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 112 | tr.ToTensor()]) 113 | 114 | return composed_transforms(sample) 115 | 116 | 117 | def __len__(self): 118 | return len(self.ids) 119 | 120 | 121 | 122 | if __name__ == "__main__": 123 | from dataloaders import custom_transforms as tr 124 | from dataloaders.utils import decode_segmap 125 | from torch.utils.data import DataLoader 126 | from torchvision import transforms 127 | import matplotlib.pyplot as plt 128 | import argparse 129 | 130 | parser = argparse.ArgumentParser() 131 | args = parser.parse_args() 132 | args.base_size = 513 133 | args.crop_size = 513 134 | 135 | coco_val = COCOSegmentation(args, split='val', year='2017') 136 | 137 | dataloader = DataLoader(coco_val, batch_size=4, shuffle=True, num_workers=0) 138 | 139 | for ii, sample in enumerate(dataloader): 140 | for jj in range(sample["image"].size()[0]): 141 | img = sample['image'].numpy() 142 | gt = sample['label'].numpy() 143 | tmp = np.array(gt[jj]).astype(np.uint8) 144 | segmap = decode_segmap(tmp, dataset='coco') 145 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 146 | img_tmp *= (0.229, 0.224, 0.225) 147 | img_tmp += (0.485, 0.456, 0.406) 148 | img_tmp *= 255.0 149 | img_tmp = img_tmp.astype(np.uint8) 150 | plt.figure() 151 | plt.title('display') 152 | plt.subplot(211) 153 | plt.imshow(img_tmp) 154 | plt.subplot(212) 155 | plt.imshow(segmap) 156 | 157 | if ii == 1: 158 | break 159 | 160 | plt.show(block=True) -------------------------------------------------------------------------------- /Segmentation/dataloaders/datasets/combine_dbs.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | 4 | class CombineDBs(data.Dataset): 5 | NUM_CLASSES = 21 6 | def __init__(self, dataloaders, excluded=None): 7 | self.dataloaders = dataloaders 8 | self.excluded = excluded 9 | self.im_ids = [] 10 | 11 | # Combine object lists 12 | for dl in dataloaders: 13 | for elem in dl.im_ids: 14 | if elem not in self.im_ids: 15 | self.im_ids.append(elem) 16 | 17 | # Exclude 18 | if excluded: 19 | for dl in excluded: 20 | for elem in dl.im_ids: 21 | if elem in self.im_ids: 22 | self.im_ids.remove(elem) 23 | 24 | # Get object pointers 25 | self.cat_list = [] 26 | self.im_list 
= []
27 |         new_im_ids = []
28 |         num_images = 0
29 |         for ii, dl in enumerate(dataloaders):
30 |             for jj, curr_im_id in enumerate(dl.im_ids):
31 |                 if (curr_im_id in self.im_ids) and (curr_im_id not in new_im_ids):
32 |                     num_images += 1
33 |                     new_im_ids.append(curr_im_id)
34 |                     self.cat_list.append({'db_ii': ii, 'cat_ii': jj})
35 | 
36 |         self.im_ids = new_im_ids
37 |         print('Combined number of images: {:d}'.format(num_images))
38 | 
39 |     def __getitem__(self, index):
40 | 
41 |         _db_ii = self.cat_list[index]["db_ii"]
42 |         _cat_ii = self.cat_list[index]['cat_ii']
43 |         sample = self.dataloaders[_db_ii].__getitem__(_cat_ii)
44 | 
45 |         if 'meta' in sample.keys():
46 |             sample['meta']['db'] = str(self.dataloaders[_db_ii])
47 | 
48 |         return sample
49 | 
50 |     def __len__(self):
51 |         return len(self.cat_list)
52 | 
53 |     def __str__(self):
54 |         include_db = [str(db) for db in self.dataloaders]
55 |         exclude_db = [str(db) for db in self.excluded]
56 |         return 'Included datasets:'+str(include_db)+'\n'+'Excluded datasets:'+str(exclude_db)
57 | 
58 | 
59 | if __name__ == "__main__":
60 |     import matplotlib.pyplot as plt
61 |     from dataloaders.datasets import pascal
62 |     from dataloaders.datasets import sbd
63 |     import torch
64 |     import numpy as np
65 |     from dataloaders.utils import decode_segmap
66 |     import argparse
67 | 
68 |     parser = argparse.ArgumentParser()
69 |     args = parser.parse_args()
70 |     args.base_size = 513
71 |     args.crop_size = 513
72 | 
73 |     pascal_voc_val = pascal.VOCSegmentation(args, split='val')
74 |     sbd = sbd.SBDSegmentation(args, split=['train', 'val'])
75 |     pascal_voc_train = pascal.VOCSegmentation(args, split='train')
76 | 
77 |     dataset = CombineDBs([pascal_voc_train, sbd], excluded=[pascal_voc_val])
78 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
79 | 
80 |     for ii, sample in enumerate(dataloader):
81 |         for jj in range(sample["image"].size()[0]):
82 |             img = sample['image'].numpy()
83 |             gt = sample['label'].numpy()
84 |             tmp = np.array(gt[jj]).astype(np.uint8)
85 |             segmap = decode_segmap(tmp, dataset='pascal')
86 |             img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
87 |             img_tmp *= (0.229, 0.224, 0.225)
88 |             img_tmp += (0.485, 0.456, 0.406)
89 |             img_tmp *= 255.0
90 |             img_tmp = img_tmp.astype(np.uint8)
91 |             plt.figure()
92 |             plt.title('display')
93 |             plt.subplot(211)
94 |             plt.imshow(img_tmp)
95 |             plt.subplot(212)
96 |             plt.imshow(segmap)
97 | 
98 |             if ii == 1:
99 |                 break
100 |     plt.show(block=True)
--------------------------------------------------------------------------------
/Segmentation/dataloaders/datasets/pascal.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | from mypath import Path
7 | from torchvision import transforms
8 | from dataloaders import custom_transforms as tr
9 | 
10 | class VOCSegmentation(Dataset):
11 |     """
12 |     PascalVoc dataset
13 |     """
14 |     NUM_CLASSES = 21
15 | 
16 |     def __init__(self,
17 |                  args,
18 |                  base_dir=Path.db_root_dir('pascal'),
19 |                  split='train',
20 |                  ):
21 |         """
22 |         :param base_dir: path to VOC dataset directory
23 |         :param split: train/val
24 |         :param transform: transform to apply
25 |         """
26 |         super().__init__()
27 |         self._base_dir = base_dir
28 |         self._image_dir = os.path.join(self._base_dir, 'JPEGImages')
29 |         self._cat_dir = os.path.join(self._base_dir, 'SegmentationClass')
30 | 
31 |         if isinstance(split, str):
32 |             self.split = [split]
33
| else: 34 | split.sort() 35 | self.split = split 36 | 37 | self.args = args 38 | 39 | _splits_dir = os.path.join(self._base_dir, 'ImageSets', 'Segmentation') 40 | 41 | self.im_ids = [] 42 | self.images = [] 43 | self.categories = [] 44 | 45 | for splt in self.split: 46 | with open(os.path.join(os.path.join(_splits_dir, splt + '.txt')), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for ii, line in enumerate(lines): 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _cat = os.path.join(self._cat_dir, line + ".png") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_cat) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_cat) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images in {}: {:d}'.format(split, len(self.images))) 62 | 63 | def __len__(self): 64 | return len(self.images) 65 | 66 | 67 | def __getitem__(self, index): 68 | _img, _target = self._make_img_gt_point_pair(index) 69 | sample = {'image': _img, 'label': _target} 70 | 71 | for split in self.split: 72 | if split == "train": 73 | return self.transform_tr(sample) 74 | elif split == 'val': 75 | return self.transform_val(sample) 76 | 77 | 78 | def _make_img_gt_point_pair(self, index): 79 | _img = Image.open(self.images[index]).convert('RGB') 80 | _target = Image.open(self.categories[index]) 81 | 82 | return _img, _target 83 | 84 | def transform_tr(self, sample): 85 | composed_transforms = transforms.Compose([ 86 | tr.RandomHorizontalFlip(), 87 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 88 | tr.RandomGaussianBlur(), 89 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 90 | tr.ToTensor()]) 91 | 92 | return composed_transforms(sample) 93 | 94 | def transform_val(self, sample): 95 | 96 | composed_transforms = transforms.Compose([ 97 | tr.FixedResize(size=self.args.crop_size), 98 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 99 | tr.ToTensor()]) 100 | 101 | return composed_transforms(sample) 102 | 103 | def __str__(self): 104 | return 'VOC2012(split=' + str(self.split) + ')' 105 | 106 | 107 | if __name__ == '__main__': 108 | from dataloaders.utils import decode_segmap 109 | from torch.utils.data import DataLoader 110 | import matplotlib.pyplot as plt 111 | import argparse 112 | 113 | parser = argparse.ArgumentParser() 114 | args = parser.parse_args() 115 | args.base_size = 513 116 | args.crop_size = 513 117 | 118 | voc_train = VOCSegmentation(args, split='train') 119 | 120 | dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0) 121 | 122 | for ii, sample in enumerate(dataloader): 123 | for jj in range(sample["image"].size()[0]): 124 | img = sample['image'].numpy() 125 | gt = sample['label'].numpy() 126 | tmp = np.array(gt[jj]).astype(np.uint8) 127 | segmap = decode_segmap(tmp, dataset='pascal') 128 | img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 129 | img_tmp *= (0.229, 0.224, 0.225) 130 | img_tmp += (0.485, 0.456, 0.406) 131 | img_tmp *= 255.0 132 | img_tmp = img_tmp.astype(np.uint8) 133 | plt.figure() 134 | plt.title('display') 135 | plt.subplot(211) 136 | plt.imshow(img_tmp) 137 | plt.subplot(212) 138 | plt.imshow(segmap) 139 | 140 | if ii == 1: 141 | break 142 | 143 | plt.show(block=True) 144 | 145 | 146 | -------------------------------------------------------------------------------- /Segmentation/dataloaders/datasets/sbd.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | 4 | import numpy as np 5 | import scipy.io 6 | import torch.utils.data as data 7 | from PIL import Image 8 | from mypath import Path 9 | 10 | from torchvision import transforms 11 | from dataloaders import custom_transforms as tr 12 | 13 | class SBDSegmentation(data.Dataset): 14 | NUM_CLASSES = 21 15 | 16 | def __init__(self, 17 | args, 18 | base_dir=Path.db_root_dir('sbd'), 19 | split='train', 20 | ): 21 | """ 22 | :param base_dir: path to VOC dataset directory 23 | :param split: train/val 24 | :param transform: transform to apply 25 | """ 26 | super().__init__() 27 | self._base_dir = base_dir 28 | self._dataset_dir = os.path.join(self._base_dir, 'dataset') 29 | self._image_dir = os.path.join(self._dataset_dir, 'img') 30 | self._cat_dir = os.path.join(self._dataset_dir, 'cls') 31 | 32 | 33 | if isinstance(split, str): 34 | self.split = [split] 35 | else: 36 | split.sort() 37 | self.split = split 38 | 39 | self.args = args 40 | 41 | # Get list of all images from the split and check that the files exist 42 | self.im_ids = [] 43 | self.images = [] 44 | self.categories = [] 45 | for splt in self.split: 46 | with open(os.path.join(self._dataset_dir, splt + '.txt'), "r") as f: 47 | lines = f.read().splitlines() 48 | 49 | for line in lines: 50 | _image = os.path.join(self._image_dir, line + ".jpg") 51 | _categ= os.path.join(self._cat_dir, line + ".mat") 52 | assert os.path.isfile(_image) 53 | assert os.path.isfile(_categ) 54 | self.im_ids.append(line) 55 | self.images.append(_image) 56 | self.categories.append(_categ) 57 | 58 | assert (len(self.images) == len(self.categories)) 59 | 60 | # Display stats 61 | print('Number of images: {:d}'.format(len(self.images))) 62 | 63 | 64 | def __getitem__(self, index): 65 | _img, _target = self._make_img_gt_point_pair(index) 66 | sample = {'image': _img, 'label': _target} 67 | 68 | return self.transform(sample) 69 | 70 | def __len__(self): 71 | return len(self.images) 72 | 73 | def _make_img_gt_point_pair(self, index): 74 | _img = Image.open(self.images[index]).convert('RGB') 75 | _target = Image.fromarray(scipy.io.loadmat(self.categories[index])["GTcls"][0]['Segmentation'][0]) 76 | 77 | return _img, _target 78 | 79 | def transform(self, sample): 80 | composed_transforms = transforms.Compose([ 81 | tr.RandomHorizontalFlip(), 82 | tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size), 83 | tr.RandomGaussianBlur(), 84 | tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 85 | tr.ToTensor()]) 86 | 87 | return composed_transforms(sample) 88 | 89 | 90 | def __str__(self): 91 | return 'SBDSegmentation(split=' + str(self.split) + ')' 92 | 93 | 94 | if __name__ == '__main__': 95 | from dataloaders.utils import decode_segmap 96 | from torch.utils.data import DataLoader 97 | import matplotlib.pyplot as plt 98 | import argparse 99 | 100 | parser = argparse.ArgumentParser() 101 | args = parser.parse_args() 102 | args.base_size = 513 103 | args.crop_size = 513 104 | 105 | sbd_train = SBDSegmentation(args, split='train') 106 | dataloader = DataLoader(sbd_train, batch_size=2, shuffle=True, num_workers=2) 107 | 108 | for ii, sample in enumerate(dataloader): 109 | for jj in range(sample["image"].size()[0]): 110 | img = sample['image'].numpy() 111 | gt = sample['label'].numpy() 112 | tmp = np.array(gt[jj]).astype(np.uint8) 113 | segmap = decode_segmap(tmp, dataset='pascal') 114 | 
img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) 115 | img_tmp *= (0.229, 0.224, 0.225) 116 | img_tmp += (0.485, 0.456, 0.406) 117 | img_tmp *= 255.0 118 | img_tmp = img_tmp.astype(np.uint8) 119 | plt.figure() 120 | plt.title('display') 121 | plt.subplot(211) 122 | plt.imshow(img_tmp) 123 | plt.subplot(212) 124 | plt.imshow(segmap) 125 | 126 | if ii == 1: 127 | break 128 | 129 | plt.show(block=True) -------------------------------------------------------------------------------- /Segmentation/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | 5 | def decode_seg_map_sequence(label_masks, dataset='pascal'): 6 | rgb_masks = [] 7 | for label_mask in label_masks: 8 | rgb_mask = decode_segmap(label_mask, dataset) 9 | rgb_masks.append(rgb_mask) 10 | rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) 11 | return rgb_masks 12 | 13 | 14 | def decode_segmap(label_mask, dataset, plot=False): 15 | """Decode segmentation class labels into a color image 16 | Args: 17 | label_mask (np.ndarray): an (M,N) array of integer values denoting 18 | the class label at each spatial location. 19 | plot (bool, optional): whether to show the resulting color image 20 | in a figure. 21 | Returns: 22 | (np.ndarray, optional): the resulting decoded color image. 23 | """ 24 | if dataset == 'pascal' or dataset == 'coco': 25 | n_classes = 21 26 | label_colours = get_pascal_labels() 27 | elif dataset == 'cityscapes': 28 | n_classes = 19 29 | label_colours = get_cityscapes_labels() 30 | else: 31 | raise NotImplementedError 32 | 33 | r = label_mask.copy() 34 | g = label_mask.copy() 35 | b = label_mask.copy() 36 | for ll in range(0, n_classes): 37 | r[label_mask == ll] = label_colours[ll, 0] 38 | g[label_mask == ll] = label_colours[ll, 1] 39 | b[label_mask == ll] = label_colours[ll, 2] 40 | rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) 41 | rgb[:, :, 0] = r / 255.0 42 | rgb[:, :, 1] = g / 255.0 43 | rgb[:, :, 2] = b / 255.0 44 | if plot: 45 | plt.imshow(rgb) 46 | plt.show() 47 | else: 48 | return rgb 49 | 50 | 51 | def encode_segmap(mask): 52 | """Encode segmentation label images as pascal classes 53 | Args: 54 | mask (np.ndarray): raw segmentation label image of dimension 55 | (M, N, 3), in which the Pascal classes are encoded as colours. 56 | Returns: 57 | (np.ndarray): class map with dimensions (M,N), where the value at 58 | a given location is the integer denoting the class index. 
59 | """ 60 | mask = mask.astype(int) 61 | label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) 62 | for ii, label in enumerate(get_pascal_labels()): 63 | label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii 64 | label_mask = label_mask.astype(int) 65 | return label_mask 66 | 67 | 68 | def get_cityscapes_labels(): 69 | return np.array([ 70 | [128, 64, 128], 71 | [244, 35, 232], 72 | [70, 70, 70], 73 | [102, 102, 156], 74 | [190, 153, 153], 75 | [153, 153, 153], 76 | [250, 170, 30], 77 | [220, 220, 0], 78 | [107, 142, 35], 79 | [152, 251, 152], 80 | [0, 130, 180], 81 | [220, 20, 60], 82 | [255, 0, 0], 83 | [0, 0, 142], 84 | [0, 0, 70], 85 | [0, 60, 100], 86 | [0, 80, 100], 87 | [0, 0, 230], 88 | [119, 11, 32]]) 89 | 90 | 91 | def get_pascal_labels(): 92 | """Load the mapping that associates pascal classes with label colors 93 | Returns: 94 | np.ndarray with dimensions (21, 3) 95 | """ 96 | return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], 97 | [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], 98 | [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], 99 | [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], 100 | [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], 101 | [0, 64, 128]]) -------------------------------------------------------------------------------- /Segmentation/distiller.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from scipy.stats import norm 5 | import scipy 6 | 7 | import math 8 | 9 | def distillation_loss(source, target, margin): 10 | target = torch.max(target, margin) 11 | loss = torch.nn.functional.mse_loss(source, target, reduction="none") 12 | loss = loss * ((source > target) | (target > 0)).float() 13 | return loss.sum() 14 | 15 | def build_feature_connector(t_channel, s_channel): 16 | C = [nn.Conv2d(s_channel, t_channel, kernel_size=1, stride=1, padding=0, bias=False), 17 | nn.BatchNorm2d(t_channel)] 18 | 19 | for m in C: 20 | if isinstance(m, nn.Conv2d): 21 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 22 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
23 |         elif isinstance(m, nn.BatchNorm2d):
24 |             m.weight.data.fill_(1)
25 |             m.bias.data.zero_()
26 | 
27 |     return nn.Sequential(*C)
28 | 
# Annotation: per-channel margin derived from the teacher's BN statistics.
# Treating the pre-ReLU activation of a channel as x ~ N(m, s^2), with (s, m)
# taken from the BN scale and shift, the appended value is
# E[x | x < 0] = m - s * phi(m/s) / Phi(-m/s), the expected negative response;
# the -3s branch guards the numerically unstable case where Phi(-m/s) -> 0.
29 | def get_margin_from_BN(bn):
30 |     margin = []
31 |     std = bn.weight.data
32 |     mean = bn.bias.data
33 |     for (s, m) in zip(std, mean):
34 |         s = abs(s.item())
35 |         m = m.item()
36 |         if norm.cdf(-m / s) > 0.001:
37 |             margin.append(- s * math.exp(- (m / s) ** 2 / 2) / math.sqrt(2 * math.pi) / norm.cdf(-m / s) + m)
38 |         else:
39 |             margin.append(-3 * s)
40 | 
41 |     return torch.FloatTensor(margin).to(std.device)
42 | 
43 | class Distiller(nn.Module):
44 |     def __init__(self, t_net, s_net):
45 |         super(Distiller, self).__init__()
46 | 
47 |         t_channels = t_net.get_channel_num()
48 |         s_channels = s_net.get_channel_num()
49 | 
50 |         self.Connectors = nn.ModuleList([build_feature_connector(t, s) for t, s in zip(t_channels, s_channels)])
51 | 
52 |         teacher_bns = t_net.get_bn_before_relu()
53 |         margins = [get_margin_from_BN(bn) for bn in teacher_bns]
54 |         for i, margin in enumerate(margins):
55 |             self.register_buffer('margin%d' % (i+1), margin.unsqueeze(1).unsqueeze(2).unsqueeze(0).detach())
56 | 
57 |         self.t_net = t_net
58 |         self.s_net = s_net
59 | 
        # Annotation: fixed weights for the six distilled features
        # (four backbone stages, one ASPP output, one decoder feature).
60 |         self.loss_divider = [8, 4, 2, 1, 1, 4*4]
61 | 
62 |     def forward(self, x):
63 | 
64 |         t_feats, t_out = self.t_net.extract_feature(x)
65 |         s_feats, s_out = self.s_net.extract_feature(x)
66 |         feat_num = len(t_feats)
67 | 
68 |         loss_distill = 0
69 |         for i in range(feat_num):
70 |             s_feats[i] = self.Connectors[i](s_feats[i])
71 |             loss_distill += distillation_loss(s_feats[i], t_feats[i].detach(), getattr(self, 'margin%d' % (i+1))) \
72 |                             / self.loss_divider[i]
73 | 
74 |         return s_out, loss_distill
75 | 
--------------------------------------------------------------------------------
/Segmentation/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/overhaul-distillation/76344a84a7ce23c894f41a2e05b866c9b73fd85a/Segmentation/modeling/__init__.py
--------------------------------------------------------------------------------
/Segmentation/modeling/aspp.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
6 | 
7 | class _ASPPModule(nn.Module):
8 |     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
9 |         super(_ASPPModule, self).__init__()
10 |         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
11 |                                      stride=1, padding=padding, dilation=dilation, bias=False)
12 |         self.bn = BatchNorm(planes)
13 |         self.relu = nn.ReLU()
14 | 
15 |         self._init_weight()
16 | 
17 |     def forward(self, x):
18 |         x = self.atrous_conv(x)
19 |         x = self.bn(x)
20 | 
21 |         return self.relu(x)
22 | 
23 |     def _init_weight(self):
24 |         for m in self.modules():
25 |             if isinstance(m, nn.Conv2d):
26 |                 torch.nn.init.kaiming_normal_(m.weight)
27 |             elif isinstance(m, SynchronizedBatchNorm2d):
28 |                 m.weight.data.fill_(1)
29 |                 m.bias.data.zero_()
30 |             elif isinstance(m, nn.BatchNorm2d):
31 |                 m.weight.data.fill_(1)
32 |                 m.bias.data.zero_()
33 | 
34 | class ASPP(nn.Module):
35 |     def __init__(self, backbone, output_stride, BatchNorm):
36 |         super(ASPP, self).__init__()
37 |         if backbone == 'drn':
38 |             inplanes = 512
39 |         elif backbone == 'mobilenet':
40 |             inplanes = 320
41 |         elif backbone == 'resnet18':
42 |             inplanes = 512
43 |         else:
44 |             inplanes = 2048
45 |         if
output_stride == 16: 46 | dilations = [1, 6, 12, 18] 47 | elif output_stride == 8: 48 | dilations = [1, 12, 24, 36] 49 | else: 50 | raise NotImplementedError 51 | 52 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 53 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 54 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 55 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 56 | 57 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 58 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 59 | BatchNorm(256), 60 | nn.ReLU()) 61 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 62 | self.bn1 = BatchNorm(256) 63 | self.relu = nn.ReLU() 64 | self.dropout = nn.Dropout(0.5) 65 | self._init_weight() 66 | 67 | def forward(self, x): 68 | x1 = self.aspp1(x) 69 | x2 = self.aspp2(x) 70 | x3 = self.aspp3(x) 71 | x4 = self.aspp4(x) 72 | x5 = self.global_avg_pool(x) 73 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 74 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 75 | 76 | x = self.conv1(x) 77 | x = self.bn1(x) 78 | x = self.relu(x) 79 | 80 | return self.dropout(x) 81 | 82 | def get_bn_before_relu(self): 83 | return [self.bn1] 84 | 85 | def get_channel_num(self): 86 | return [256] 87 | 88 | 89 | def extract_feature(self, x): 90 | x1 = self.aspp1(x) 91 | x2 = self.aspp2(x) 92 | x3 = self.aspp3(x) 93 | x4 = self.aspp4(x) 94 | x5 = self.global_avg_pool(x) 95 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 96 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 97 | 98 | x = self.conv1(x) 99 | x = self.bn1(x) 100 | feat1 = x 101 | x = self.relu(x) 102 | 103 | return [feat1], x 104 | 105 | def _init_weight(self): 106 | for m in self.modules(): 107 | if isinstance(m, nn.Conv2d): 108 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 109 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 110 | torch.nn.init.kaiming_normal_(m.weight) 111 | elif isinstance(m, SynchronizedBatchNorm2d): 112 | m.weight.data.fill_(1) 113 | m.bias.data.zero_() 114 | elif isinstance(m, nn.BatchNorm2d): 115 | m.weight.data.fill_(1) 116 | m.bias.data.zero_() 117 | 118 | 119 | def build_aspp(backbone, output_stride, BatchNorm): 120 | return ASPP(backbone, output_stride, BatchNorm) -------------------------------------------------------------------------------- /Segmentation/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from modeling.backbone import resnet, xception, drn, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet101': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'xception': 7 | return xception.AlignedXception(output_stride, BatchNorm) 8 | elif backbone == 'drn': 9 | return drn.drn_d_54(BatchNorm) 10 | elif backbone == 'mobilenet': 11 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 12 | elif backbone == 'resnet18': 13 | return resnet.ResNet18(output_stride, BatchNorm) 14 | else: 15 | raise NotImplementedError 16 | -------------------------------------------------------------------------------- /Segmentation/modeling/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | import torch.utils.model_zoo as model_zoo 7 | 8 | def conv_bn(inp, oup, stride, BatchNorm): 9 | return nn.Sequential( 10 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 11 | BatchNorm(oup), 12 | nn.ReLU6(inplace=False) 13 | ) 14 | 15 | 16 | def fixed_padding(inputs, kernel_size, dilation): 17 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 18 | pad_total = kernel_size_effective - 1 19 | pad_beg = pad_total // 2 20 | pad_end = pad_total - pad_beg 21 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 22 | return padded_inputs 23 | 24 | 25 | class InvertedResidual(nn.Module): 26 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 27 | super(InvertedResidual, self).__init__() 28 | self.stride = stride 29 | assert stride in [1, 2] 30 | 31 | hidden_dim = round(inp * expand_ratio) 32 | self.use_res_connect = self.stride == 1 and inp == oup 33 | self.kernel_size = 3 34 | self.dilation = dilation 35 | 36 | if expand_ratio == 1: 37 | self.conv = nn.Sequential( 38 | # dw 39 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 40 | BatchNorm(hidden_dim), 41 | nn.ReLU6(inplace=False), 42 | # pw-linear 43 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 44 | BatchNorm(oup), 45 | ) 46 | else: 47 | self.conv = nn.Sequential( 48 | # pw 49 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 50 | BatchNorm(hidden_dim), 51 | nn.ReLU6(inplace=False), 52 | # dw 53 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 54 | BatchNorm(hidden_dim), 55 | nn.ReLU6(inplace=False), 56 | # pw-linear 57 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 58 | BatchNorm(oup), 59 | ) 60 | 61 | def forward(self, x): 62 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 63 | if self.use_res_connect: 64 | x = x + self.conv(x_pad) 65 | else: 66 | x = self.conv(x_pad) 67 | return x 68 | 69 | 70 | class 
MobileNetV2(nn.Module):
71 |     def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=True):
72 |         super(MobileNetV2, self).__init__()
73 |         block = InvertedResidual
74 |         input_channel = 32
75 |         current_stride = 1
76 |         rate = 1
77 |         inverted_residual_setting = [
78 |             # t, c, n, s
79 |             [1, 16, 1, 1],
80 |             [6, 24, 2, 2],
81 |             [6, 32, 3, 2],
82 |             [6, 64, 4, 2],
83 |             [6, 96, 3, 1],
84 |             [6, 160, 3, 2],
85 |             [6, 320, 1, 1],
86 |         ]
87 | 
88 |         # building first layer
89 |         input_channel = int(input_channel * width_mult)
90 |         self.features = [conv_bn(3, input_channel, 2, BatchNorm)]
91 |         current_stride *= 2
92 |         # building inverted residual blocks
93 |         for t, c, n, s in inverted_residual_setting:
94 |             if current_stride == output_stride:
95 |                 stride = 1
96 |                 dilation = rate
97 |                 rate *= s
98 |             else:
99 |                 stride = s
100 |                 dilation = 1
101 |                 current_stride *= s
102 |             output_channel = int(c * width_mult)
103 |             for i in range(n):
104 |                 if i == 0:
105 |                     self.features.append(block(input_channel, output_channel, stride, dilation, t, BatchNorm))
106 |                 else:
107 |                     self.features.append(block(input_channel, output_channel, 1, dilation, t, BatchNorm))
108 |                 input_channel = output_channel
109 |         self.features = nn.Sequential(*self.features)
110 |         self._initialize_weights()
111 | 
112 |         if pretrained:
113 |             self._load_pretrained_model()
114 | 
115 |         self.low_level_features = self.features[0:4]
116 |         self.high_level_features = self.features[4:]
117 | 
118 |     def forward(self, x):
119 |         low_level_feat = self.low_level_features(x)
120 |         x = self.high_level_features(low_level_feat)
121 |         return x, low_level_feat
122 | 
123 |     def get_bn_before_relu(self):
124 |         bn1 = self.features[4].conv[1]
125 |         bn2 = self.features[7].conv[1]
126 |         bn3 = self.features[14].conv[1]
127 |         bn4 = self.features[-1].conv[-1]
128 | 
129 |         return [bn1, bn2, bn3, bn4]
130 | 
131 |     def get_channel_num(self):
132 |         return [144, 192, 576, 320]
133 | 
134 | 
135 |     def extract_feature(self, x):
136 | 
137 |         feat1 = self.features[0:4](x)
138 |         low_level_feat = feat1
139 |         feat2 = self.features[4:7](feat1)
140 |         feat3 = self.features[7:14](feat2)
141 |         feat4 = self.features[14:](feat3)
142 |         out = feat4
143 | 
144 |         # preReLU
145 |         feat1 = self.features[4].conv[0:2](feat1)
146 |         feat2 = self.features[7].conv[0:2](feat2)
147 |         feat3 = self.features[14].conv[0:2](feat3)
148 | 
149 |         return [feat1, feat2, feat3, feat4], out, low_level_feat
150 | 
151 |     def _load_pretrained_model(self):
152 |         pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth')
153 |         model_dict = {}
154 |         state_dict = self.state_dict()
155 |         for k, v in pretrain_dict.items():
156 |             if k in state_dict:
157 |                 model_dict[k] = v
158 |         state_dict.update(model_dict)
159 |         self.load_state_dict(state_dict)
160 | 
161 |     def _initialize_weights(self):
162 |         for m in self.modules():
163 |             if isinstance(m, nn.Conv2d):
164 |                 # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
165 |                 # m.weight.data.normal_(0, math.sqrt(2.
/ n)) 166 | torch.nn.init.kaiming_normal_(m.weight) 167 | elif isinstance(m, SynchronizedBatchNorm2d): 168 | m.weight.data.fill_(1) 169 | m.bias.data.zero_() 170 | elif isinstance(m, nn.BatchNorm2d): 171 | m.weight.data.fill_(1) 172 | m.bias.data.zero_() 173 | 174 | if __name__ == "__main__": 175 | input = torch.rand(1, 3, 512, 512) 176 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 177 | output, low_level_feat = model(input) 178 | print(output.size()) 179 | print(low_level_feat.size()) 180 | -------------------------------------------------------------------------------- /Segmentation/modeling/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | import torch.nn.functional as F 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, padding=dilation, bias=False) 13 | self.bn1 = BatchNorm(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, dilation=dilation, padding=dilation, bias=False) 15 | self.bn2 = BatchNorm(planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | self.downsample = downsample 19 | self.stride = stride 20 | 21 | def forward(self, x): 22 | x = F.relu(x) 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | 32 | if self.downsample is not None: 33 | residual = self.downsample(x) 34 | 35 | out += residual 36 | # out = self.relu(out) 37 | 38 | return out 39 | 40 | class Bottleneck(nn.Module): 41 | expansion = 4 42 | 43 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 44 | super(Bottleneck, self).__init__() 45 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 46 | self.bn1 = BatchNorm(planes) 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | dilation=dilation, padding=dilation, bias=False) 49 | self.bn2 = BatchNorm(planes) 50 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 51 | self.bn3 = BatchNorm(planes * 4) 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stride = stride 55 | self.dilation = dilation 56 | 57 | def forward(self, x): 58 | x = F.relu(x) 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv2(out) 66 | out = self.bn2(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv3(out) 70 | out = self.bn3(out) 71 | 72 | if self.downsample is not None: 73 | residual = self.downsample(x) 74 | 75 | out += residual 76 | # out = self.relu(out) 77 | 78 | return out 79 | 80 | class ResNet(nn.Module): 81 | 82 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=True): 83 | self.inplanes = 64 84 | super(ResNet, self).__init__() 85 | blocks = [1, 2, 4] 86 | if output_stride == 16: 87 | strides = [1, 2, 2, 1] 88 | dilations = [1, 1, 1, 2] 89 | elif output_stride == 8: 90 | strides = [1, 2, 1, 1] 91 | dilations = [1, 1, 2, 4] 92 | else: 93 | raise NotImplementedError 94 | 95 | # Modules 96 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, 
padding=3,
97 |                                bias=False)
98 |         self.bn1 = BatchNorm(64)
99 |         self.relu = nn.ReLU(inplace=True)
100 |         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
101 | 
102 |         self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm)
103 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm)
104 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm)
105 |         if isinstance(self.layer1[0], BasicBlock):
106 |             self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
107 |         else:
108 |             self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
109 |         # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm)
110 |         self._init_weight()
111 | 
112 |         if pretrained:
113 |             self._load_pretrained_model()
114 | 
115 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
116 |         downsample = None
117 |         if stride != 1 or self.inplanes != planes * block.expansion:
118 |             downsample = nn.Sequential(
119 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
120 |                           kernel_size=1, stride=stride, bias=False),
121 |                 BatchNorm(planes * block.expansion),
122 |             )
123 | 
124 |         layers = []
125 |         layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm))
126 |         self.inplanes = planes * block.expansion
127 |         for i in range(1, blocks):
128 |             layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm))
129 | 
130 |         return nn.Sequential(*layers)
131 | 
132 |     def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
133 |         downsample = None
134 |         if stride != 1 or self.inplanes != planes * block.expansion:
135 |             downsample = nn.Sequential(
136 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
137 |                           kernel_size=1, stride=stride, bias=False),
138 |                 BatchNorm(planes * block.expansion),
139 |             )
140 | 
141 |         layers = []
142 |         layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation,
143 |                             downsample=downsample, BatchNorm=BatchNorm))
144 |         self.inplanes = planes * block.expansion
145 |         for i in range(1, len(blocks)):
146 |             layers.append(block(self.inplanes, planes, stride=1,
147 |                                 dilation=blocks[i]*dilation, BatchNorm=BatchNorm))
148 | 
149 |         return nn.Sequential(*layers)
150 | 
151 |     def forward(self, input):
152 |         x = self.conv1(input)
153 |         x = self.bn1(x)
154 |         x = self.relu(x)
155 |         x = self.maxpool(x)
156 | 
157 |         x = self.layer1(x)
158 |         low_level_feat = F.relu(x)
159 |         x = self.layer2(x)
160 |         x = self.layer3(x)
161 |         x = self.layer4(x)
162 |         x = F.relu(x)
163 |         return x, low_level_feat
164 | 
165 |     def get_bn_before_relu(self):
166 |         if isinstance(self.layer1[0], Bottleneck):
167 |             bn1 = self.layer1[-1].bn3
168 |             bn2 = self.layer2[-1].bn3
169 |             bn3 = self.layer3[-1].bn3
170 |             bn4 = self.layer4[-1].bn3
171 |         elif isinstance(self.layer1[0], BasicBlock):
172 |             bn1 = self.layer1[-1].bn2
173 |             bn2 = self.layer2[-1].bn2
174 |             bn3 = self.layer3[-1].bn2
175 |             bn4 = self.layer4[-1].bn2
176 |         else:
177 |             raise TypeError('ResNet: unknown block type, cannot locate the BatchNorms before ReLU')
178 | 
179 |         return [bn1, bn2, bn3, bn4]
180 | 
181 |     def get_channel_num(self):
182 |         if isinstance(self.layer1[0], Bottleneck):
183 |             return [256, 512, 1024, 2048]
184 |         elif isinstance(self.layer1[0], BasicBlock):
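            # Annotation: BasicBlock has expansion 1, so each stage's channel
            # count is simply its planes value (64, 128, 256, 512); Bottleneck
            # above multiplies by its expansion factor of 4.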
185 | return [64, 128, 256, 512] 186 | 187 | 188 | def extract_feature(self, x): 189 | 190 | x = self.conv1(x) 191 | x = self.bn1(x) 192 | x = self.relu(x) 193 | x = self.maxpool(x) 194 | 195 | feat1 = self.layer1(x) 196 | low_level_feat = F.relu(feat1) 197 | feat2 = self.layer2(feat1) 198 | feat3 = self.layer3(feat2) 199 | feat4 = self.layer4(feat3) 200 | out = F.relu(feat4) 201 | 202 | return [feat1, feat2, feat3, feat4], out, low_level_feat 203 | 204 | def _init_weight(self): 205 | for m in self.modules(): 206 | if isinstance(m, nn.Conv2d): 207 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 208 | m.weight.data.normal_(0, math.sqrt(2. / n)) 209 | elif isinstance(m, SynchronizedBatchNorm2d): 210 | m.weight.data.fill_(1) 211 | m.bias.data.zero_() 212 | elif isinstance(m, nn.BatchNorm2d): 213 | m.weight.data.fill_(1) 214 | m.bias.data.zero_() 215 | 216 | def _load_pretrained_model(self): 217 | if isinstance(self.layer1[0], BasicBlock): 218 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet18-5c106cde.pth') 219 | else: 220 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 221 | model_dict = {} 222 | state_dict = self.state_dict() 223 | for k, v in pretrain_dict.items(): 224 | if k in state_dict: 225 | model_dict[k] = v 226 | state_dict.update(model_dict) 227 | self.load_state_dict(state_dict) 228 | 229 | def ResNet101(output_stride, BatchNorm, pretrained=True): 230 | """Constructs a ResNet-101 model. 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 235 | return model 236 | 237 | import torchvision.models.resnet 238 | def ResNet18(output_stride, BatchNorm, pretrained=True): 239 | """Constructs a ResNet-18 model. 
240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | """ 243 | model = ResNet(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm, pretrained=pretrained) 244 | return model 245 | -------------------------------------------------------------------------------- /Segmentation/modeling/backbone/xception.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 7 | 8 | def fixed_padding(inputs, kernel_size, dilation): 9 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 10 | pad_total = kernel_size_effective - 1 11 | pad_beg = pad_total // 2 12 | pad_end = pad_total - pad_beg 13 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 14 | return padded_inputs 15 | 16 | 17 | class SeparableConv2d(nn.Module): 18 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, BatchNorm=None): 19 | super(SeparableConv2d, self).__init__() 20 | 21 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, 0, dilation, 22 | groups=inplanes, bias=bias) 23 | self.bn = BatchNorm(inplanes) 24 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 25 | 26 | def forward(self, x): 27 | x = fixed_padding(x, self.conv1.kernel_size[0], dilation=self.conv1.dilation[0]) 28 | x = self.conv1(x) 29 | x = self.bn(x) 30 | x = self.pointwise(x) 31 | return x 32 | 33 | 34 | class Block(nn.Module): 35 | def __init__(self, inplanes, planes, reps, stride=1, dilation=1, BatchNorm=None, 36 | start_with_relu=True, grow_first=True, is_last=False): 37 | super(Block, self).__init__() 38 | 39 | if planes != inplanes or stride != 1: 40 | self.skip = nn.Conv2d(inplanes, planes, 1, stride=stride, bias=False) 41 | self.skipbn = BatchNorm(planes) 42 | else: 43 | self.skip = None 44 | 45 | self.relu = nn.ReLU(inplace=True) 46 | rep = [] 47 | 48 | filters = inplanes 49 | if grow_first: 50 | rep.append(self.relu) 51 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 52 | rep.append(BatchNorm(planes)) 53 | filters = planes 54 | 55 | for i in range(reps - 1): 56 | rep.append(self.relu) 57 | rep.append(SeparableConv2d(filters, filters, 3, 1, dilation, BatchNorm=BatchNorm)) 58 | rep.append(BatchNorm(filters)) 59 | 60 | if not grow_first: 61 | rep.append(self.relu) 62 | rep.append(SeparableConv2d(inplanes, planes, 3, 1, dilation, BatchNorm=BatchNorm)) 63 | rep.append(BatchNorm(planes)) 64 | 65 | if stride != 1: 66 | rep.append(self.relu) 67 | rep.append(SeparableConv2d(planes, planes, 3, 2, BatchNorm=BatchNorm)) 68 | rep.append(BatchNorm(planes)) 69 | 70 | if stride == 1 and is_last: 71 | rep.append(self.relu) 72 | rep.append(SeparableConv2d(planes, planes, 3, 1, BatchNorm=BatchNorm)) 73 | rep.append(BatchNorm(planes)) 74 | 75 | if not start_with_relu: 76 | rep = rep[1:] 77 | 78 | self.rep = nn.Sequential(*rep) 79 | 80 | def forward(self, inp): 81 | x = self.rep(inp) 82 | 83 | if self.skip is not None: 84 | skip = self.skip(inp) 85 | skip = self.skipbn(skip) 86 | else: 87 | skip = inp 88 | 89 | x = x + skip 90 | 91 | return x 92 | 93 | 94 | class AlignedXception(nn.Module): 95 | """ 96 | Modified Alighed Xception 97 | """ 98 | def __init__(self, output_stride, BatchNorm, 99 | pretrained=True): 100 | super(AlignedXception, self).__init__() 101 | 102 
| if output_stride == 16: 103 | entry_block3_stride = 2 104 | middle_block_dilation = 1 105 | exit_block_dilations = (1, 2) 106 | elif output_stride == 8: 107 | entry_block3_stride = 1 108 | middle_block_dilation = 2 109 | exit_block_dilations = (2, 4) 110 | else: 111 | raise NotImplementedError 112 | 113 | 114 | # Entry flow 115 | self.conv1 = nn.Conv2d(3, 32, 3, stride=2, padding=1, bias=False) 116 | self.bn1 = BatchNorm(32) 117 | self.relu = nn.ReLU(inplace=True) 118 | 119 | self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1, bias=False) 120 | self.bn2 = BatchNorm(64) 121 | 122 | self.block1 = Block(64, 128, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False) 123 | self.block2 = Block(128, 256, reps=2, stride=2, BatchNorm=BatchNorm, start_with_relu=False, 124 | grow_first=True) 125 | self.block3 = Block(256, 728, reps=2, stride=entry_block3_stride, BatchNorm=BatchNorm, 126 | start_with_relu=True, grow_first=True, is_last=True) 127 | 128 | # Middle flow 129 | self.block4 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 130 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 131 | self.block5 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 132 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 133 | self.block6 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 134 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 135 | self.block7 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 136 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 137 | self.block8 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 138 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 139 | self.block9 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 140 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 141 | self.block10 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 142 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 143 | self.block11 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 144 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 145 | self.block12 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 146 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 147 | self.block13 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 148 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 149 | self.block14 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 150 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 151 | self.block15 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 152 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 153 | self.block16 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 154 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 155 | self.block17 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 156 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 157 | self.block18 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 158 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 159 | self.block19 = Block(728, 728, reps=3, stride=1, dilation=middle_block_dilation, 160 | BatchNorm=BatchNorm, start_with_relu=True, grow_first=True) 161 | 162 | # Exit flow 163 | self.block20 = Block(728, 1024, reps=2, stride=1, dilation=exit_block_dilations[0], 164 | 
BatchNorm=BatchNorm, start_with_relu=True, grow_first=False, is_last=True) 165 | 166 | self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 167 | self.bn3 = BatchNorm(1536) 168 | 169 | self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 170 | self.bn4 = BatchNorm(1536) 171 | 172 | self.conv5 = SeparableConv2d(1536, 2048, 3, stride=1, dilation=exit_block_dilations[1], BatchNorm=BatchNorm) 173 | self.bn5 = BatchNorm(2048) 174 | 175 | # Init weights 176 | self._init_weight() 177 | 178 | # Load pretrained model 179 | if pretrained: 180 | self._load_pretrained_model() 181 | 182 | def forward(self, x): 183 | # Entry flow 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | 188 | x = self.conv2(x) 189 | x = self.bn2(x) 190 | x = self.relu(x) 191 | 192 | x = self.block1(x) 193 | # add relu here 194 | x = self.relu(x) 195 | low_level_feat = x 196 | x = self.block2(x) 197 | x = self.block3(x) 198 | 199 | # Middle flow 200 | x = self.block4(x) 201 | x = self.block5(x) 202 | x = self.block6(x) 203 | x = self.block7(x) 204 | x = self.block8(x) 205 | x = self.block9(x) 206 | x = self.block10(x) 207 | x = self.block11(x) 208 | x = self.block12(x) 209 | x = self.block13(x) 210 | x = self.block14(x) 211 | x = self.block15(x) 212 | x = self.block16(x) 213 | x = self.block17(x) 214 | x = self.block18(x) 215 | x = self.block19(x) 216 | 217 | # Exit flow 218 | x = self.block20(x) 219 | x = self.relu(x) 220 | x = self.conv3(x) 221 | x = self.bn3(x) 222 | x = self.relu(x) 223 | 224 | x = self.conv4(x) 225 | x = self.bn4(x) 226 | x = self.relu(x) 227 | 228 | x = self.conv5(x) 229 | x = self.bn5(x) 230 | x = self.relu(x) 231 | 232 | return x, low_level_feat 233 | 234 | def _init_weight(self): 235 | for m in self.modules(): 236 | if isinstance(m, nn.Conv2d): 237 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 238 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 239 | elif isinstance(m, SynchronizedBatchNorm2d): 240 | m.weight.data.fill_(1) 241 | m.bias.data.zero_() 242 | elif isinstance(m, nn.BatchNorm2d): 243 | m.weight.data.fill_(1) 244 | m.bias.data.zero_() 245 | 246 | 247 | def _load_pretrained_model(self): 248 | pretrain_dict = model_zoo.load_url('http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth') 249 | model_dict = {} 250 | state_dict = self.state_dict() 251 | 252 | for k, v in pretrain_dict.items(): 253 | if k in model_dict: 254 | if 'pointwise' in k: 255 | v = v.unsqueeze(-1).unsqueeze(-1) 256 | if k.startswith('block11'): 257 | model_dict[k] = v 258 | model_dict[k.replace('block11', 'block12')] = v 259 | model_dict[k.replace('block11', 'block13')] = v 260 | model_dict[k.replace('block11', 'block14')] = v 261 | model_dict[k.replace('block11', 'block15')] = v 262 | model_dict[k.replace('block11', 'block16')] = v 263 | model_dict[k.replace('block11', 'block17')] = v 264 | model_dict[k.replace('block11', 'block18')] = v 265 | model_dict[k.replace('block11', 'block19')] = v 266 | elif k.startswith('block12'): 267 | model_dict[k.replace('block12', 'block20')] = v 268 | elif k.startswith('bn3'): 269 | model_dict[k] = v 270 | model_dict[k.replace('bn3', 'bn4')] = v 271 | elif k.startswith('conv4'): 272 | model_dict[k.replace('conv4', 'conv5')] = v 273 | elif k.startswith('bn4'): 274 | model_dict[k.replace('bn4', 'bn5')] = v 275 | else: 276 | model_dict[k] = v 277 | state_dict.update(model_dict) 278 | self.load_state_dict(state_dict) 279 | 280 | 281 | 282 | if __name__ == "__main__": 283 | import torch 284 | model = AlignedXception(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=16) 285 | input = torch.rand(1, 3, 512, 512) 286 | output, low_level_feat = model(input) 287 | print(output.size()) 288 | print(low_level_feat.size()) -------------------------------------------------------------------------------- /Segmentation/modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | class Decoder(nn.Module): 8 | def __init__(self, num_classes, backbone, BatchNorm): 9 | super(Decoder, self).__init__() 10 | if backbone == 'resnet101' or backbone == 'drn': 11 | low_level_inplanes = 256 12 | elif backbone == 'xception': 13 | low_level_inplanes = 128 14 | elif backbone == 'mobilenet': 15 | low_level_inplanes = 24 16 | elif backbone == 'resnet18': 17 | low_level_inplanes = 64 18 | else: 19 | raise NotImplementedError 20 | 21 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 22 | self.bn1 = BatchNorm(48) 23 | self.relu = nn.ReLU() 24 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 25 | BatchNorm(256), 26 | nn.ReLU(), 27 | nn.Dropout(0.5), 28 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 29 | BatchNorm(256), 30 | nn.ReLU(), 31 | nn.Dropout(0.1), 32 | nn.Conv2d(256, num_classes, kernel_size=1, stride=1)) 33 | self._init_weight() 34 | 35 | 36 | def forward(self, x, low_level_feat): 37 | low_level_feat = self.conv1(low_level_feat) 38 | low_level_feat = self.bn1(low_level_feat) 39 | low_level_feat = self.relu(low_level_feat) 40 | 41 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 42 | x = torch.cat((x, low_level_feat), dim=1) 43 | x = self.last_conv(x) 44 | 45 | return x 46 | 47 | 
def get_bn_before_relu(self):
48 |         return [self.last_conv[5]]
49 | 
50 |     def get_channel_num(self):
51 |         return [256]
52 | 
53 |     def extract_feature(self, x, low_level_feat):
54 |         low_level_feat = self.conv1(low_level_feat)
55 |         low_level_feat = self.bn1(low_level_feat)
56 |         low_level_feat = self.relu(low_level_feat)
57 | 
58 |         x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)
59 |         x = torch.cat((x, low_level_feat), dim=1)
60 |         x = self.last_conv[0:6](x)
61 |         feat1 = x
62 |         x = self.last_conv[6:](x)
63 | 
64 |         return [feat1], x
65 | 
66 |     def _init_weight(self):
67 |         for m in self.modules():
68 |             if isinstance(m, nn.Conv2d):
69 |                 torch.nn.init.kaiming_normal_(m.weight)
70 |             elif isinstance(m, SynchronizedBatchNorm2d):
71 |                 m.weight.data.fill_(1)
72 |                 m.bias.data.zero_()
73 |             elif isinstance(m, nn.BatchNorm2d):
74 |                 m.weight.data.fill_(1)
75 |                 m.bias.data.zero_()
76 | 
77 | def build_decoder(num_classes, backbone, BatchNorm):
78 |     return Decoder(num_classes, backbone, BatchNorm)
--------------------------------------------------------------------------------
/Segmentation/modeling/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
5 | from modeling.aspp import build_aspp
6 | from modeling.decoder import build_decoder
7 | from modeling.backbone import build_backbone
8 | 
9 | class DeepLab(nn.Module):
10 |     def __init__(self, backbone='resnet101', output_stride=16, num_classes=21,
11 |                  sync_bn=True, freeze_bn=False):
12 |         super(DeepLab, self).__init__()
13 |         if backbone == 'drn':
14 |             output_stride = 8
15 | 
16 |         if sync_bn:
17 |             BatchNorm = SynchronizedBatchNorm2d
18 |         else:
19 |             BatchNorm = nn.BatchNorm2d
20 | 
21 |         self.backbone = build_backbone(backbone, output_stride, BatchNorm)
22 |         self.aspp = build_aspp(backbone, output_stride, BatchNorm)
23 |         self.decoder = build_decoder(num_classes, backbone, BatchNorm)
24 | 
25 |         if freeze_bn:
26 |             self.freeze_bn()
27 | 
28 |     def forward(self, input):
29 |         x, low_level_feat = self.backbone(input)
30 |         x = self.aspp(x)
31 |         x = self.decoder(x, low_level_feat)
32 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
33 | 
34 |         return x
35 | 
36 |     def freeze_bn(self):
37 |         for m in self.modules():
38 |             if isinstance(m, SynchronizedBatchNorm2d):
39 |                 m.eval()
40 |             elif isinstance(m, nn.BatchNorm2d):
41 |                 m.eval()
42 | 
43 |     def get_1x_lr_params(self):
44 |         modules = [self.backbone]
45 |         for i in range(len(modules)):
46 |             for m in modules[i].named_modules():
47 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
48 |                         or isinstance(m[1], nn.BatchNorm2d):
49 |                     for p in m[1].parameters():
50 |                         if p.requires_grad:
51 |                             yield p
52 | 
53 |     def get_10x_lr_params(self):
54 |         modules = [self.aspp, self.decoder]
55 |         for i in range(len(modules)):
56 |             for m in modules[i].named_modules():
57 |                 if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
58 |                         or isinstance(m[1], nn.BatchNorm2d):
59 |                     for p in m[1].parameters():
60 |                         if p.requires_grad:
61 |                             yield p
62 | 
63 |     def get_bn_before_relu(self):
64 |         BNs = self.backbone.get_bn_before_relu()
65 |         BNs += self.aspp.get_bn_before_relu()
66 |         BNs += self.decoder.get_bn_before_relu()
67 | 
68 |         return BNs
69 | 
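    # Annotation: the hook lists line up across these methods -- four backbone
    # stages plus one ASPP feature plus one decoder feature, six entries in all,
    # matching the six weights in Distiller.loss_divider earlier in this dump.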
70 |     def get_channel_num(self):
71 |         channels = self.backbone.get_channel_num()
72 |         channels += self.aspp.get_channel_num()
73 |         channels += self.decoder.get_channel_num()
74 | 
75 |         return channels
76 | 
77 |     def extract_feature(self, input):
78 |         feats, x, low_level_feat = self.backbone.extract_feature(input)
79 |         feat, x = self.aspp.extract_feature(x)
80 |         feats += feat
81 |         feat, x = self.decoder.extract_feature(x, low_level_feat)
82 |         feats += feat
83 |         x = F.interpolate(x, size=input.size()[2:], mode='bilinear', align_corners=True)
84 | 
85 |         return feats, x
86 | 
87 | 
--------------------------------------------------------------------------------
/Segmentation/modeling/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
--------------------------------------------------------------------------------
/Segmentation/modeling/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import collections
12 | 
13 | import torch
14 | import torch.nn.functional as F
15 | 
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 | 
19 | from .comm import SyncMaster
20 | 
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 | 
23 | 
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 |     return tensor.sum(dim=0).sum(dim=-1)
27 | 
28 | 
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 |     return tensor.unsqueeze(0).unsqueeze(-1)
32 | 
33 | 
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 | 
37 | 
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 |     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 |         super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 | 
42 |         self._sync_master = SyncMaster(self._data_parallel_master)
43 | 
44 |         self._is_parallel = False
45 |         self._parallel_id = None
46 |         self._slave_pipe = None
47 | 
48 |     def forward(self, input):
49 |         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 |         if not (self._is_parallel and self.training):
51 |             return F.batch_norm(
52 |                 input, self.running_mean, self.running_var, self.weight, self.bias,
53 |                 self.training, self.momentum, self.eps)
54 | 
55 |         # Resize the input to (B, C, -1).
56 |         input_shape = input.size()
57 |         input = input.view(input.size(0), self.num_features, -1)
58 | 
59 |         # Compute the sum and square-sum.
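        # Annotation: only (sum, sum-of-squares, count) leave each replica; the
        # master recombines them into a global mean and inverse std in
        # _data_parallel_master / _compute_mean_std below, so the activations
        # themselves never cross devices.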
60 |         sum_size = input.size(0) * input.size(2)
61 |         input_sum = _sum_ft(input)
62 |         input_ssum = _sum_ft(input ** 2)
63 | 
64 |         # Reduce-and-broadcast the statistics.
65 |         if self._parallel_id == 0:
66 |             mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 |         else:
68 |             mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 | 
70 |         # Compute the output.
71 |         if self.affine:
72 |             # MJY:: Fuse the multiplication for speed.
73 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 |         else:
75 |             output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 | 
77 |         # Reshape it.
78 |         return output.view(input_shape)
79 | 
80 |     def __data_parallel_replicate__(self, ctx, copy_id):
81 |         self._is_parallel = True
82 |         self._parallel_id = copy_id
83 | 
84 |         # parallel_id == 0 means master device.
85 |         if self._parallel_id == 0:
86 |             ctx.sync_master = self._sync_master
87 |         else:
88 |             self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 | 
90 |     def _data_parallel_master(self, intermediates):
91 |         """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 | 
93 |         # Always using same "device order" makes the ReduceAdd operation faster.
94 |         # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 | 
97 |         to_reduce = [i[1][:2] for i in intermediates]
98 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
99 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
100 | 
101 |         sum_size = sum([i[1].sum_size for i in intermediates])
102 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 | 
105 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 | 
107 |         outputs = []
108 |         for i, rec in enumerate(intermediates):
109 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 | 
111 |         return outputs
112 | 
113 |     def _compute_mean_std(self, sum_, ssum, size):
114 |         """Compute the mean and standard-deviation with sum and square-sum. This method
115 |         also maintains the moving average on the master device."""
116 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 |         mean = sum_ / size
118 |         sumvar = ssum - sum_ * mean
119 |         unbias_var = sumvar / (size - 1)
120 |         bias_var = sumvar / size
121 | 
122 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 | 
125 |         return mean, bias_var.clamp(self.eps) ** -0.5
126 | 
127 | 
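# Annotation on the math above: sumvar = ssum - sum * mean equals
# sum((x - mean)^2), so dividing by n gives the biased variance used for
# normalization and dividing by n - 1 gives the unbiased estimate stored in
# running_var; the returned inv_std is bias_var.clamp(eps) ** -0.5.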
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 |     mini-batch.
131 |     .. math::
132 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 |     standard-deviation are reduced across all devices during training.
135 |     For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 |     is also easy to implement, but the statistics might be inaccurate.
139 |     Instead, in this synchronized version, the statistics will be computed
140 |     over all training samples distributed on multiple devices.
141 | 
142 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 |     as the built-in PyTorch implementation.
144 |     The mean and standard-deviation are calculated per-dimension over
145 |     the mini-batches and gamma and beta are learnable parameter vectors
146 |     of size C (where C is the input size).
147 |     During training, this layer keeps a running estimate of its computed mean
148 |     and variance. The running sum is kept with a default momentum of 0.1.
149 |     During evaluation, this running mean/variance is used for normalization.
150 |     Because the BatchNorm is done over the `C` dimension, computing statistics
151 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
152 |     Args:
153 |         num_features: num_features from an expected input of size
154 |             `batch_size x num_features [x width]`
155 |         eps: a value added to the denominator for numerical stability.
156 |             Default: 1e-5
157 |         momentum: the value used for the running_mean and running_var
158 |             computation. Default: 0.1
159 |         affine: a boolean value that when set to ``True``, gives the layer learnable
160 |             affine parameters. Default: ``True``
161 |     Shape:
162 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 |     Examples:
165 |         >>> # With Learnable Parameters
166 |         >>> m = SynchronizedBatchNorm1d(100)
167 |         >>> # Without Learnable Parameters
168 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 |         >>> input = torch.autograd.Variable(torch.randn(20, 100))
170 |         >>> output = m(input)
171 |     """
172 | 
173 |     def _check_input_dim(self, input):
174 |         if input.dim() != 2 and input.dim() != 3:
175 |             raise ValueError('expected 2D or 3D input (got {}D input)'
176 |                              .format(input.dim()))
177 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 | 
179 | 
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 |     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 |     of 3d inputs
183 |     .. math::
184 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 |     This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 |     standard-deviation are reduced across all devices during training.
187 |     For example, when one uses `nn.DataParallel` to wrap the network during
188 |     training, PyTorch's implementation normalizes the tensor on each device using
189 |     the statistics only on that device, which accelerates the computation and
190 |     is also easy to implement, but the statistics might be inaccurate.
191 |     Instead, in this synchronized version, the statistics will be computed
192 |     over all training samples distributed on multiple devices.
193 | 
194 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 |     as the built-in PyTorch implementation.
196 |     The mean and standard-deviation are calculated per-dimension over
197 |     the mini-batches and gamma and beta are learnable parameter vectors
198 |     of size C (where C is the input size).
199 |     During training, this layer keeps a running estimate of its computed mean
200 |     and variance. The running sum is kept with a default momentum of 0.1.
201 |     During evaluation, this running mean/variance is used for normalization.
202 |     Because the BatchNorm is done over the `C` dimension, computing statistics
203 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
204 |     Args:
205 |         num_features: num_features from an expected input of
206 |             size batch_size x num_features x height x width
207 |         eps: a value added to the denominator for numerical stability.
208 |             Default: 1e-5
209 |         momentum: the value used for the running_mean and running_var
210 |             computation. Default: 0.1
211 |         affine: a boolean value that when set to ``True``, gives the layer learnable
212 |             affine parameters. Default: ``True``
213 |     Shape:
214 |         - Input: :math:`(N, C, H, W)`
215 |         - Output: :math:`(N, C, H, W)` (same shape as input)
216 |     Examples:
217 |         >>> # With Learnable Parameters
218 |         >>> m = SynchronizedBatchNorm2d(100)
219 |         >>> # Without Learnable Parameters
220 |         >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
222 |         >>> output = m(input)
223 |     """
224 | 
225 |     def _check_input_dim(self, input):
226 |         if input.dim() != 4:
227 |             raise ValueError('expected 4D input (got {}D input)'
228 |                              .format(input.dim()))
229 |         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 | 
231 | 
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 |     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 |     of 4d inputs
235 |     .. math::
236 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 |     This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 |     standard-deviation are reduced across all devices during training.
239 |     For example, when one uses `nn.DataParallel` to wrap the network during
240 |     training, PyTorch's implementation normalizes the tensor on each device using
241 |     the statistics only on that device, which accelerates the computation and
242 |     is also easy to implement, but the statistics might be inaccurate.
243 |     Instead, in this synchronized version, the statistics will be computed
244 |     over all training samples distributed on multiple devices.
245 | 
246 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 |     as the built-in PyTorch implementation.
248 |     The mean and standard-deviation are calculated per-dimension over
249 |     the mini-batches and gamma and beta are learnable parameter vectors
250 |     of size C (where C is the input size).
251 |     During training, this layer keeps a running estimate of its computed mean
252 |     and variance. The running sum is kept with a default momentum of 0.1.
253 |     During evaluation, this running mean/variance is used for normalization.
254 |     Because the BatchNorm is done over the `C` dimension, computing statistics
255 |     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 |     or Spatio-temporal BatchNorm
257 |     Args:
258 |         num_features: num_features from an expected input of
259 |             size batch_size x num_features x depth x height x width
260 |         eps: a value added to the denominator for numerical stability.
261 |             Default: 1e-5
262 |         momentum: the value used for the running_mean and running_var
263 |             computation. Default: 0.1
264 |         affine: a boolean value that when set to ``True``, gives the layer learnable
265 |             affine parameters. Default: ``True``
266 |     Shape:
267 |         - Input: :math:`(N, C, D, H, W)`
268 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 |     Examples:
270 |         >>> # With Learnable Parameters
271 |         >>> m = SynchronizedBatchNorm3d(100)
272 |         >>> # Without Learnable Parameters
273 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 |         >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
275 |         >>> output = m(input)
276 |     """
277 | 
278 |     def _check_input_dim(self, input):
279 |         if input.dim() != 5:
280 |             raise ValueError('expected 5D input (got {}D input)'
281 |                              .format(input.dim()))
282 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
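# Usage sketch (annotation, assuming two visible GPUs): swap these classes in
# for nn.BatchNorm2d and wrap the model so that replication wires up the pipes:
#   model = DeepLab(backbone='resnet101', sync_bn=True)
#   model = DataParallelWithCallback(model, device_ids=[0, 1])
# or keep a plain DataParallel and call patch_replication_callback(model),
# as Segmentation/train.py's import of patch_replication_callback suggests.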
--------------------------------------------------------------------------------
/Segmentation/modeling/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : comm.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During replication, as data parallel triggers a callback on each module, all slave devices should
59 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
61 |       and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be passed
63 |       back to each slave device.
64 |     """
65 | 
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 | 
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 | 
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 | 
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
85 |         Args:
86 |             identifier: an identifier, usually is the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
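    # Protocol sketch (annotation): with one master and one slave replica,
    #   slave:  ret = pipe.run_slave(msg)          # posts (id, msg), blocks on its future
    #   master: out = sync.run_master(master_msg)  # drains the queue, runs the callback,
    #                                              # then puts each slave's result back
    # run_master (below) returns the callback's answer for the master itself.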
96 | 
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback will be invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 | 
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 | 
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 | 
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 | 
125 |         return results[0][1]
126 | 
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 | 
--------------------------------------------------------------------------------
/Segmentation/modeling/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File   : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email  : maojiayuan@gmail.com
5 | # Date   : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 |     Note that, as all modules are isomorphic, we assign each sub-module with a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 |     of any slave copies.
36 | """ 37 | master_copy = modules[0] 38 | nr_modules = len(list(master_copy.modules())) 39 | ctxs = [CallbackContext() for _ in range(nr_modules)] 40 | 41 | for i, module in enumerate(modules): 42 | for j, m in enumerate(module.modules()): 43 | if hasattr(m, '__data_parallel_replicate__'): 44 | m.__data_parallel_replicate__(ctxs[j], i) 45 | 46 | 47 | class DataParallelWithCallback(DataParallel): 48 | """ 49 | Data Parallel with a replication callback. 50 | The replication callback `__data_parallel_replicate__` of each module is invoked after the module copies are created by 51 | the original `replicate` function. 52 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. 53 | Examples: 54 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 55 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 56 | # sync_bn.__data_parallel_replicate__ will be invoked. 57 | """ 58 | 59 | def replicate(self, module, device_ids): 60 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 61 | execute_replication_callbacks(modules) 62 | return modules 63 | 64 | 65 | def patch_replication_callback(data_parallel): 66 | """ 67 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 68 | Useful when you have a customized `DataParallel` implementation. 69 | Examples: 70 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 71 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 72 | > patch_replication_callback(sync_bn) 73 | # this is equivalent to 74 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 75 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 76 | """ 77 | 78 | assert isinstance(data_parallel, DataParallel) 79 | 80 | old_replicate = data_parallel.replicate 81 | 82 | @functools.wraps(old_replicate) 83 | def new_replicate(module, device_ids): 84 | modules = old_replicate(module, device_ids) 85 | execute_replication_callbacks(modules) 86 | return modules 87 | 88 | data_parallel.replicate = new_replicate -------------------------------------------------------------------------------- /Segmentation/modeling/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
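#
# Illustrative usage sketch (added note, not part of the original file):
#
#   import torch
#
#   class MyCase(TorchTestCase):
#       def test_close(self):
#           x = torch.randn(4, 3)
#           self.assertTensorClose(x, x + 1e-5)  # passes: default atol is 1e-3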
10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /Segmentation/mypath.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class Path(object): 4 | 5 | @staticmethod 6 | def db_root_dir(dataset): 7 | if dataset == 'pascal': 8 | return '/path/to/datasets/VOCdevkit/VOC2012/' # folder that contains VOCdevkit/. 9 | elif dataset == 'sbd': 10 | return '/path/to/datasets/benchmark_RELEASE/' # folder that contains dataset/. 11 | elif dataset == 'cityscapes': 12 | return '/path/to/datasets/cityscapes/' # folder that contains leftImg8bit/ 13 | elif dataset == 'coco': 14 | return '/path/to/datasets/coco/' 15 | else: 16 | print('Dataset {} not available.'.format(dataset)) 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /Segmentation/pretrained/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /Segmentation/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | from mypath import Path 7 | from dataloaders import make_data_loader 8 | from modeling.sync_batchnorm.replicate import patch_replication_callback 9 | from modeling.deeplab import * 10 | from utils.loss import SegmentationLosses 11 | from utils.calculate_weights import calculate_weigths_labels 12 | from utils.lr_scheduler import LR_Scheduler 13 | from utils.saver import Saver 14 | # from utils.summaries import TensorboardSummary 15 | from utils.metrics import Evaluator 16 | 17 | class Trainer(object): 18 | def __init__(self, args): 19 | self.args = args 20 | 21 | # Define Saver 22 | self.saver = Saver(args) 23 | self.saver.save_experiment_config() 24 | 25 | # Define Dataloader 26 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 27 | self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 28 | 29 | # Define network 30 | model = DeepLab(num_classes=self.nclass, 31 | backbone=args.backbone, 32 | output_stride=args.out_stride, 33 | sync_bn=args.sync_bn, 34 | freeze_bn=args.freeze_bn) 35 | 36 | train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr}, 37 | {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}] 38 | 39 | # Define Optimizer 40 | optimizer = torch.optim.SGD(train_params, momentum=args.momentum, 41 | weight_decay=args.weight_decay, nesterov=args.nesterov) 42 | 43 | # Define Criterion 44 | # whether to use class balanced weights 45 | if args.use_balanced_weights: 46 | classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy') 47 | if os.path.isfile(classes_weights_path): 48 | weight = np.load(classes_weights_path) 49 | else: 50 |
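# (first run) scan the whole training set to count per-class pixels; the result is cached as <dataset>_classes_weights.npy (see utils/calculate_weights.py)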
weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) 51 | weight = torch.from_numpy(weight.astype(np.float32)) 52 | else: 53 | weight = None 54 | self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type) 55 | self.model, self.optimizer = model, optimizer 56 | 57 | # Define Evaluator 58 | self.evaluator = Evaluator(self.nclass) 59 | # Define lr scheduler 60 | self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, 61 | args.epochs, len(self.train_loader)) 62 | 63 | # Using cuda 64 | if args.cuda: 65 | self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) 66 | patch_replication_callback(self.model) 67 | self.model = self.model.cuda() 68 | 69 | # Resuming checkpoint 70 | self.best_pred = 0.0 71 | if args.resume is not None: 72 | if not os.path.isfile(args.resume): 73 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 74 | checkpoint = torch.load(args.resume) 75 | args.start_epoch = checkpoint['epoch'] 76 | if args.cuda: 77 | self.model.module.load_state_dict(checkpoint['state_dict']) 78 | else: 79 | self.model.load_state_dict(checkpoint['state_dict']) 80 | if not args.ft: 81 | self.optimizer.load_state_dict(checkpoint['optimizer']) 82 | self.best_pred = checkpoint['best_pred'] 83 | print("=> loaded checkpoint '{}' (epoch {})" 84 | .format(args.resume, checkpoint['epoch'])) 85 | 86 | # Clear start epoch if fine-tuning 87 | if args.ft: 88 | args.start_epoch = 0 89 | 90 | def training(self, epoch): 91 | train_loss = 0.0 92 | self.model.train() 93 | tbar = tqdm(self.train_loader) 94 | num_img_tr = len(self.train_loader) 95 | for i, sample in enumerate(tbar): 96 | image, target = sample['image'], sample['label'] 97 | if self.args.cuda: 98 | image, target = image.cuda(), target.cuda() 99 | self.scheduler(self.optimizer, i, epoch, self.best_pred) 100 | self.optimizer.zero_grad() 101 | output = self.model(image) 102 | loss = self.criterion(output, target) 103 | loss.backward() 104 | self.optimizer.step() 105 | train_loss += loss.item() 106 | tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) 107 | 108 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 109 | print('Loss: %.3f' % train_loss) 110 | 111 | if self.args.no_val: 112 | # save checkpoint every epoch 113 | is_best = False 114 | self.saver.save_checkpoint({ 115 | 'epoch': epoch + 1, 116 | 'state_dict': self.model.module.state_dict(), 117 | 'optimizer': self.optimizer.state_dict(), 118 | 'best_pred': self.best_pred, 119 | }, is_best) 120 | 121 | 122 | def validation(self, epoch): 123 | self.model.eval() 124 | self.evaluator.reset() 125 | tbar = tqdm(self.val_loader, desc='\r') 126 | test_loss = 0.0 127 | for i, sample in enumerate(tbar): 128 | image, target = sample['image'], sample['label'] 129 | if self.args.cuda: 130 | image, target = image.cuda(), target.cuda() 131 | with torch.no_grad(): 132 | output = self.model(image) 133 | loss = self.criterion(output, target) 134 | test_loss += loss.item() 135 | tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) 136 | pred = output.data.cpu().numpy() 137 | target = target.cpu().numpy() 138 | pred = np.argmax(pred, axis=1) 139 | # Add batch sample into evaluator 140 | self.evaluator.add_batch(target, pred) 141 | 142 | # Fast test during the training 143 | Acc = self.evaluator.Pixel_Accuracy() 144 | Acc_class = self.evaluator.Pixel_Accuracy_Class() 145 | mIoU = self.evaluator.Mean_Intersection_over_Union() 146 | 
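# FWIoU weights each class's IoU by its pixel frequency; mIoU above averages all classes uniformly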
FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() 147 | print('Validation:') 148 | print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) 149 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 150 | print('Loss: %.3f' % test_loss) 151 | 152 | new_pred = mIoU 153 | if new_pred > self.best_pred: 154 | is_best = True 155 | self.best_pred = new_pred 156 | self.saver.save_checkpoint({ 157 | 'epoch': epoch + 1, 158 | 'state_dict': self.model.module.state_dict(), 159 | 'optimizer': self.optimizer.state_dict(), 160 | 'best_pred': self.best_pred, 161 | }, is_best) 162 | 163 | def main(): 164 | parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training") 165 | parser.add_argument('--backbone', type=str, default='resnet18', 166 | choices=['resnet101', 'resnet18', 'xception', 'drn', 'mobilenet'], 167 | help='backbone name (default: resnet18)') 168 | parser.add_argument('--out-stride', type=int, default=16, 169 | help='network output stride (default: 16)') 170 | parser.add_argument('--dataset', type=str, default='pascal', 171 | choices=['pascal', 'coco', 'cityscapes'], 172 | help='dataset name (default: pascal)') 173 | parser.add_argument('--use-sbd', action='store_true', default=False, 174 | help='whether to use SBD dataset (default: False)') 175 | parser.add_argument('--workers', type=int, default=4, 176 | metavar='N', help='dataloader threads') 177 | parser.add_argument('--base-size', type=int, default=513, 178 | help='base image size') 179 | parser.add_argument('--crop-size', type=int, default=513, 180 | help='crop image size') 181 | parser.add_argument('--sync-bn', action='store_true', default=False, 182 | help='whether to use sync bn (default: False)') 183 | parser.add_argument('--freeze-bn', action='store_true', default=False, 184 | help='whether to freeze bn parameters (default: False)') 185 | parser.add_argument('--loss-type', type=str, default='ce', 186 | choices=['ce', 'focal'], 187 | help='loss func type (default: ce)') 188 | # training hyper params 189 | parser.add_argument('--epochs', type=int, default=None, metavar='N', 190 | help='number of epochs to train (default: auto)') 191 | parser.add_argument('--start_epoch', type=int, default=0, 192 | metavar='N', help='start epochs (default: 0)') 193 | parser.add_argument('--batch-size', type=int, default=None, 194 | metavar='N', help='input batch size for \ 195 | training (default: auto)') 196 | parser.add_argument('--test-batch-size', type=int, default=None, 197 | metavar='N', help='input batch size for \ 198 | testing (default: auto)') 199 | parser.add_argument('--use-balanced-weights', action='store_true', default=False, 200 | help='whether to use balanced weights (default: False)') 201 | # optimizer params 202 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 203 | help='learning rate (default: auto)') 204 | parser.add_argument('--lr-scheduler', type=str, default='poly', 205 | choices=['poly', 'step', 'cos'], 206 | help='lr scheduler mode (default: poly)') 207 | parser.add_argument('--momentum', type=float, default=0.9, 208 | metavar='M', help='momentum (default: 0.9)') 209 | parser.add_argument('--weight-decay', type=float, default=5e-4, 210 | metavar='M', help='w-decay (default: 5e-4)') 211 | parser.add_argument('--nesterov', action='store_true', default=False, 212 | help='whether use nesterov (default: False)') 213 | # cuda, seed and logging 214 | parser.add_argument('--no-cuda', action='store_true', default=False,
215 | help='disables CUDA training') 216 | parser.add_argument('--gpu-ids', type=str, default='0', 217 | help='use which gpu to train, must be a \ 218 | comma-separated list of integers only (default: 0)') 219 | parser.add_argument('--seed', type=int, default=1, metavar='S', 220 | help='random seed (default: 1)') 221 | # checkpointing 222 | parser.add_argument('--resume', type=str, default=None, 223 | help='put the path to resuming file if needed') 224 | parser.add_argument('--checkname', type=str, default=None, 225 | help='set the checkpoint name') 226 | # finetuning pre-trained models 227 | parser.add_argument('--ft', action='store_true', default=False, 228 | help='finetuning on a different dataset') 229 | # evaluation option 230 | parser.add_argument('--eval-interval', type=int, default=1, 231 | help='evaluation interval (default: 1)') 232 | parser.add_argument('--no-val', action='store_true', default=False, 233 | help='skip validation during training') 234 | 235 | args = parser.parse_args() 236 | args.cuda = not args.no_cuda and torch.cuda.is_available() 237 | if args.cuda: 238 | try: 239 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] 240 | except ValueError: 241 | raise ValueError('Argument --gpu-ids must be a comma-separated list of integers only') 242 | 243 | # default settings for epochs, batch_size and lr 244 | if args.epochs is None: 245 | epochs = { 246 | 'coco': 30, 247 | 'cityscapes': 200, 248 | 'pascal': 50, 249 | } 250 | args.epochs = epochs[args.dataset.lower()] 251 | 252 | if args.batch_size is None: 253 | args.batch_size = 6 * len(args.gpu_ids) 254 | 255 | if args.test_batch_size is None: 256 | args.test_batch_size = args.batch_size 257 | 258 | if args.lr is None: 259 | lrs = { 260 | 'coco': 0.1, 261 | 'cityscapes': 0.01, 262 | 'pascal': 0.007, 263 | } 264 | args.lr = lrs[args.dataset.lower()] / (6 * len(args.gpu_ids)) * args.batch_size 265 | 266 | 267 | if args.checkname is None: 268 | args.checkname = 'deeplab-'+str(args.backbone) 269 | print(args) 270 | torch.manual_seed(args.seed) 271 | trainer = Trainer(args) 272 | print('Starting Epoch:', trainer.args.start_epoch) 273 | print('Total Epochs:', trainer.args.epochs) 274 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 275 | trainer.training(epoch) 276 | if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1): 277 | trainer.validation(epoch) 278 | 279 | if __name__ == "__main__": 280 | main() 281 | -------------------------------------------------------------------------------- /Segmentation/train_voc.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --backbone resnet101 --lr 0.007 --workers 4 --use-sbd --epochs 50 --batch-size 16 --gpu-ids 0,1,2,3 --checkname deeplab-resnet --eval-interval 1 --dataset pascal 2 | -------------------------------------------------------------------------------- /Segmentation/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/overhaul-distillation/76344a84a7ce23c894f41a2e05b866c9b73fd85a/Segmentation/utils/__init__.py -------------------------------------------------------------------------------- /Segmentation/utils/calculate_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | from mypath import Path 5 | 6 | def
calculate_weigths_labels(dataset, dataloader, num_classes): 7 | # Accumulate per-class pixel counts over the whole dataset 8 | z = np.zeros((num_classes,)) 9 | # Initialize tqdm 10 | tqdm_batch = tqdm(dataloader) 11 | print('Calculating class weights') 12 | for sample in tqdm_batch: 13 | y = sample['label'] 14 | y = y.detach().cpu().numpy() 15 | mask = (y >= 0) & (y < num_classes) 16 | labels = y[mask].astype(np.uint8) 17 | count_l = np.bincount(labels, minlength=num_classes) 18 | z += count_l 19 | tqdm_batch.close() 20 | total_frequency = np.sum(z) 21 | class_weights = [] 22 | for frequency in z: 23 | class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) 24 | class_weights.append(class_weight) 25 | ret = np.array(class_weights) 26 | classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy') 27 | np.save(classes_weights_path, ret) 28 | 29 | return ret -------------------------------------------------------------------------------- /Segmentation/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class SegmentationLosses(object): 5 | def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False): 6 | self.ignore_index = ignore_index 7 | self.weight = weight 8 | self.size_average = size_average 9 | self.batch_average = batch_average 10 | self.cuda = cuda 11 | 12 | def build_loss(self, mode='ce'): 13 | """Choices: ['ce' or 'focal']""" 14 | if mode == 'ce': 15 | return self.CrossEntropyLoss 16 | elif mode == 'focal': 17 | return self.FocalLoss 18 | else: 19 | raise NotImplementedError 20 | 21 | def CrossEntropyLoss(self, logit, target): 22 | n, c, h, w = logit.size() 23 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 24 | size_average=self.size_average) 25 | if self.cuda: 26 | criterion = criterion.cuda() 27 | 28 | loss = criterion(logit, target.long()) 29 | 30 | if self.batch_average: 31 | loss /= n 32 | 33 | return loss 34 | 35 | def FocalLoss(self, logit, target, gamma=2, alpha=0.5): 36 | n, c, h, w = logit.size() 37 | criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, 38 | size_average=self.size_average) 39 | if self.cuda: 40 | criterion = criterion.cuda() 41 | 42 | logpt = -criterion(logit, target.long()) 43 | pt = torch.exp(logpt) 44 | if alpha is not None: 45 | logpt *= alpha 46 | loss = -((1 - pt) ** gamma) * logpt 47 | 48 | if self.batch_average: 49 | loss /= n 50 | 51 | return loss 52 | 53 | if __name__ == "__main__": 54 | loss = SegmentationLosses(cuda=True) 55 | a = torch.rand(1, 3, 7, 7).cuda() 56 | b = torch.rand(1, 7, 7).cuda() 57 | print(loss.CrossEntropyLoss(a, b).item()) 58 | print(loss.FocalLoss(a, b, gamma=0, alpha=None).item()) 59 | print(loss.FocalLoss(a, b, gamma=2, alpha=0.5).item()) 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /Segmentation/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
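## Worked example (added note, not part of the original header): with
## mode='poly', base_lr=0.007 and N = epochs * iters_per_epoch total
## iterations, iteration T uses lr = 0.007 * (1 - T / N) ** 0.9, so the
## rate decays smoothly from 0.007 towards zero over training.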
10 | 11 | import math 12 | 13 | class LR_Scheduler(object): 14 | """Learning Rate Scheduler 15 | 16 | Step mode: ``lr = baselr * 0.1 ^ floor(epoch / lr_step)`` 17 | 18 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(pi * iter / maxiter))`` 19 | 20 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 21 | 22 | Args: 23 | args: 24 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`, `step`), 25 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 26 | :attr:`args.lr_step` 27 | 28 | iters_per_epoch: number of iterations per epoch 29 | """ 30 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 31 | lr_step=0, warmup_epochs=0): 32 | self.mode = mode 33 | print('Using {} LR Scheduler!'.format(self.mode)) 34 | self.lr = base_lr 35 | if mode == 'step': 36 | assert lr_step 37 | self.lr_step = lr_step 38 | self.iters_per_epoch = iters_per_epoch 39 | self.N = num_epochs * iters_per_epoch 40 | self.epoch = -1 41 | self.warmup_iters = warmup_epochs * iters_per_epoch 42 | 43 | def __call__(self, optimizer, i, epoch, best_pred): 44 | T = epoch * self.iters_per_epoch + i 45 | if self.mode == 'cos': 46 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 47 | elif self.mode == 'poly': 48 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 49 | elif self.mode == 'step': 50 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 51 | else: 52 | raise NotImplementedError 53 | # warm up lr schedule 54 | if self.warmup_iters > 0 and T < self.warmup_iters: 55 | lr = lr * 1.0 * T / self.warmup_iters 56 | if epoch > self.epoch: 57 | print('\n=> Epoch %i, learning rate = %.4f, \ 58 | previous best = %.4f' % (epoch, lr, best_pred)) 59 | self.epoch = epoch 60 | assert lr >= 0 61 | self._adjust_learning_rate(optimizer, lr) 62 | 63 | def _adjust_learning_rate(self, optimizer, lr): 64 | if len(optimizer.param_groups) == 1: 65 | optimizer.param_groups[0]['lr'] = lr 66 | else: 67 | # enlarge the lr at the head 68 | optimizer.param_groups[0]['lr'] = lr 69 | for i in range(1, len(optimizer.param_groups)): 70 | optimizer.param_groups[i]['lr'] = lr * 10 71 | -------------------------------------------------------------------------------- /Segmentation/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Evaluator(object): 5 | def __init__(self, num_class): 6 | self.num_class = num_class 7 | self.confusion_matrix = np.zeros((self.num_class,)*2) 8 | 9 | def Pixel_Accuracy(self): 10 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 11 | return Acc 12 | 13 | def Pixel_Accuracy_Class(self): 14 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 15 | Acc = np.nanmean(Acc) 16 | return Acc 17 | 18 | def Mean_Intersection_over_Union(self): 19 | MIoU = np.diag(self.confusion_matrix) / ( 20 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 21 | np.diag(self.confusion_matrix)) 22 | MIoU = np.nanmean(MIoU) 23 | return MIoU 24 | 25 | def Frequency_Weighted_Intersection_over_Union(self): 26 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 27 | iu = np.diag(self.confusion_matrix) / ( 28 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 29 | np.diag(self.confusion_matrix)) 30 | 31 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 32 | return FWIoU 33 | 34 | def _generate_matrix(self, gt_image, pre_image): 35 | mask = (gt_image >= 0) & (gt_image < self.num_class) 36 | label =
self.num_class * gt_image[mask].astype('int') + pre_image[mask] 37 | count = np.bincount(label, minlength=self.num_class**2) 38 | confusion_matrix = count.reshape(self.num_class, self.num_class) 39 | return confusion_matrix 40 | 41 | def add_batch(self, gt_image, pre_image): 42 | assert gt_image.shape == pre_image.shape 43 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 44 | 45 | def reset(self): 46 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /Segmentation/utils/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | from collections import OrderedDict 5 | import glob 6 | 7 | class Saver(object): 8 | 9 | def __init__(self, args): 10 | self.args = args 11 | self.directory = os.path.join('run', args.dataset, args.checkname) 12 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 13 | run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 14 | 15 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 16 | if not os.path.exists(self.experiment_dir): 17 | os.makedirs(self.experiment_dir) 18 | 19 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 20 | """Saves checkpoint to disk""" 21 | filename = os.path.join(self.experiment_dir, filename) 22 | torch.save(state, filename) 23 | if is_best: 24 | best_pred = state['best_pred'] 25 | with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: 26 | f.write(str(best_pred)) 27 | if self.runs: 28 | previous_miou = [0.0] 29 | for run in self.runs: 30 | run_id = run.split('_')[-1] 31 | path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') 32 | if os.path.exists(path): 33 | with open(path, 'r') as f: 34 | miou = float(f.readline()) 35 | previous_miou.append(miou) 36 | else: 37 | continue 38 | max_miou = max(previous_miou) 39 | if best_pred > max_miou: 40 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 41 | else: 42 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 43 | 44 | def save_experiment_config(self): 45 | logfile = os.path.join(self.experiment_dir, 'parameters.txt') 46 | log_file = open(logfile, 'w') 47 | p = OrderedDict() 48 | p['dataset'] = self.args.dataset 49 | p['backbone'] = self.args.backbone 50 | p['out_stride'] = self.args.out_stride 51 | p['lr'] = self.args.lr 52 | p['lr_scheduler'] = self.args.lr_scheduler 53 | p['loss_type'] = self.args.loss_type 54 | p['epoch'] = self.args.epochs 55 | p['base_size'] = self.args.base_size 56 | p['crop_size'] = self.args.crop_size 57 | 58 | for key, val in p.items(): 59 | log_file.write(key + ':' + str(val) + '\n') 60 | log_file.close() -------------------------------------------------------------------------------- /Segmentation/utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.utils import make_grid 4 | from tensorboardX import SummaryWriter 5 | from dataloaders.utils import decode_seg_map_sequence 6 | 7 | class TensorboardSummary(object): 8 | def __init__(self, directory): 9 | self.directory = directory 10 | 11 | def create_summary(self): 12 | writer = SummaryWriter(log_dir=os.path.join(self.directory)) 13 | return writer 14 | 15 | def visualize_image(self, writer,
dataset, image, target, output, global_step): 16 | grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) 17 | writer.add_image('Image', grid_image, global_step) 18 | grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), 19 | dataset=dataset), 3, normalize=False, range=(0, 255)) 20 | writer.add_image('Predicted label', grid_image, global_step) 21 | grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), 22 | dataset=dataset), 3, normalize=False, range=(0, 255)) 23 | writer.add_image('Groundtruth label', grid_image, global_step) -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | --------------------------------------------------------------------------------
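A minimal usage sketch for the `Evaluator` in Segmentation/utils/metrics.py (illustrative only; the toy arrays below are not from the repository):

import numpy as np
# from utils.metrics import Evaluator  # when run from within Segmentation/

evaluator = Evaluator(num_class=3)
gt = np.array([[0, 1], [2, 1]])    # ground-truth label map
pred = np.array([[0, 1], [1, 1]])  # predicted label map
evaluator.add_batch(gt, pred)      # accumulates a 3x3 confusion matrix
print(evaluator.Mean_Intersection_over_Union())  # ~0.556: per-class IoUs are 1.0, 2/3 and 0.0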