├── _config.yml
├── Hyper-parameter_1.jpg
├── Hyper-parameter_2.jpg
├── .github
│   └── ISSUE_TEMPLATE
│       ├── custom.md
│       ├── feature_request.md
│       └── bug_report.md
├── LICENSE
├── README.md
├── my_pooling.py
├── CUB-200-2011.py
└── CUB-200-2011_ResNet18.py

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-modernist

--------------------------------------------------------------------------------
/Hyper-parameter_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRIS-CV/Mutual-Channel-Loss/HEAD/Hyper-parameter_1.jpg

--------------------------------------------------------------------------------
/Hyper-parameter_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRIS-CV/Mutual-Channel-Loss/HEAD/Hyper-parameter_2.jpg

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/custom.md:
--------------------------------------------------------------------------------
---
name: Custom issue template
about: Describe this issue template's purpose here.
title: ''
labels: ''
assignees: ''

---

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.

--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Expected behavior**
A clear and concise description of what you expected to happen.

**Screenshots**
If applicable, add screenshots to help explain your problem.

**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]

**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]

**Additional context**
Add any other context about the problem here.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Dongliang Chang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification

Code release for "The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification" (TIP 2020)
[DOI](https://doi.org/10.1109/TIP.2020.2973812 "DOI")


## Changelog
- 2020/09/14 add `CUB-200-2011_ResNet18.py`: training with ResNet18 (trained from scratch).
- 2020/04/19 add the hyper-parameter fine-tuning results.
- 2020/04/18 clean the code for better readability.

## Dataset
### CUB-200-2011

## Requirements

- python 3.6
- PyTorch 1.2.0
- torchvision

## Training
- Download the dataset
- Train: `python CUB-200-2011.py`; `alpha_1` and `beta_1` are the hyper-parameters of the MC-Loss (see the sketch below)
- Description: PyTorch CUB-200-2011 training with VGG16 (trained from scratch).

## Hyper-parameter
`Loss = ce_loss + alpha_1 * L_dis + beta_1 * L_div`
![Hyper-parameter_1](https://github.com/dongliangchang/Mutual-Channel-Loss/blob/master/Hyper-parameter_1.jpg)
![Hyper-parameter_2](https://github.com/dongliangchang/Mutual-Channel-Loss/blob/master/Hyper-parameter_2.jpg)
The figures are plotted with NNI.
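
A minimal sketch of how the objective is assembled in `CUB-200-2011.py` (taken from its training loop; `supervisor` returns `[L_dis, L_div]` as `MC_loss`):

```python
out, ce_loss, MC_loss = net(inputs, targets)   # MC_loss = [L_dis, L_div]
loss = (ce_loss
        + args["alpha_1"] * MC_loss[0]         # discriminality term
        + args["beta_1"] * MC_loss[1])         # diversity term
```

Both weights can be set on the command line, e.g. `python CUB-200-2011.py --alpha_1 1.5 --beta_1 20.0` (the argparse defaults).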
## Other versions
Other unofficial implementations can be found at the following links:
- Kurumi233: wraps the MC-Loss into a reusable class. [code](https://github.com/Kurumi233/Mutual-Channel-Loss "code")
- darcula1993: a TensorFlow implementation of the MC-Loss. [code](https://github.com/darcula1993/Mutual-Channel-Loss "code")
- Holocron: implementations of recent deep-learning tricks in computer vision, easily paired up with your favorite framework and model zoo. [code](https://github.com/frgfm/Holocron "code")


## Citation
If you find this paper useful in your research, please consider citing:
```
@ARTICLE{9005389,
  author={D. {Chang} and Y. {Ding} and J. {Xie} and A. K. {Bhunia} and X. {Li} and Z. {Ma} and M. {Wu} and J. {Guo} and Y. {Song}},
  journal={IEEE Transactions on Image Processing},
  title={The Devil is in the Channels: Mutual-Channel Loss for Fine-Grained Image Classification},
  year={2020},
  volume={29},
  pages={4683-4695},
  doi={10.1109/TIP.2020.2973812},
  ISSN={1941-0042},}
```


## Contact
Thanks for your attention!
If you have any suggestions or questions, you can leave a message here or contact us directly:
- changdongliang@bupt.edu.cn
- mazhanyu@bupt.edu.cn

--------------------------------------------------------------------------------
/my_pooling.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.modules.utils import _pair


class my_MaxPool2d(Module):
    """Max pooling across the channel dimension.

    The input is transposed from (N, C, H, W) to (N, W, H, C) so that
    F.max_pool2d with a (1, k) kernel pools over groups of k channels
    while leaving the spatial map untouched.
    """

    def __init__(self, kernel_size, stride=None, padding=0, dilation=1,
                 return_indices=False, ceil_mode=False):
        super(my_MaxPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode

    def forward(self, input):
        # (N, C, H, W) -> (N, W, H, C): pool along the channel axis.
        input = input.transpose(3, 1)
        input = F.max_pool2d(input, self.kernel_size, self.stride,
                             self.padding, self.dilation, self.ceil_mode,
                             self.return_indices)
        # Transpose back to (N, C', H, W).
        input = input.transpose(3, 1).contiguous()
        return input

    def __repr__(self):
        kh, kw = _pair(self.kernel_size)
        dh, dw = _pair(self.stride)
        padh, padw = _pair(self.padding)
        dilh, dilw = _pair(self.dilation)
        padding_str = ', padding=(' + str(padh) + ', ' + str(padw) + ')' \
            if padh != 0 or padw != 0 else ''
        # Dilation defaults to 1, so only report it when it is non-default
        # (the previous condition `dilh != 0 and dilw != 0` always fired).
        dilation_str = (', dilation=(' + str(dilh) + ', ' + str(dilw) + ')'
                        if dilh != 1 or dilw != 1 else '')
        ceil_str = ', ceil_mode=' + str(self.ceil_mode)
        return self.__class__.__name__ + '(' \
            + 'kernel_size=(' + str(kh) + ', ' + str(kw) + ')' \
            + ', stride=(' + str(dh) + ', ' + str(dw) + ')' \
            + padding_str + dilation_str + ceil_str + ')'


class my_AvgPool2d(Module):
    """Average pooling across the channel dimension (same transpose trick)."""

    def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False,
                 count_include_pad=True):
        super(my_AvgPool2d, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.ceil_mode = ceil_mode
        self.count_include_pad = count_include_pad

    def forward(self, input):
        input = input.transpose(3, 1)
        input = F.avg_pool2d(input, self.kernel_size, self.stride,
                             self.padding, self.ceil_mode, self.count_include_pad)
        input = input.transpose(3, 1).contiguous()
        return input

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'kernel_size=' + str(self.kernel_size) \
            + ', stride=' + str(self.stride) \
            + ', padding=' + str(self.padding) \
            + ', ceil_mode=' + str(self.ceil_mode) \
            + ', count_include_pad=' + str(self.count_include_pad) + ')'


if __name__ == '__main__':
    # Quick self-test; guarded so it no longer runs on import.
    m = my_MaxPool2d((1, 32), stride=(1, 32))
    x = torch.randn(3, 2208, 7, 7)
    print(m(x).size())  # torch.Size([3, 69, 7, 7])
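
# Shape walk-through for the self-test above (an annotation, not extra API):
#   (3, 2208, 7, 7) --transpose(3, 1)--> (3, 7, 7, 2208)
#   max_pool2d, kernel=(1, 32), stride=(1, 32) --> (3, 7, 7, 69)
#   --transpose(3, 1)--> (3, 69, 7, 7)
# Every group of 32 consecutive channels collapses to a single channel
# (2208 / 32 = 69); the 7x7 spatial map is untouched.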
--------------------------------------------------------------------------------
/CUB-200-2011.py:
--------------------------------------------------------------------------------
'''PyTorch CUB-200-2011 Training with VGG16 (TRAINED FROM SCRATCH).'''
from __future__ import print_function
import os
# import nni
import logging
import argparse
import random
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from my_pooling import my_MaxPool2d


logger = logging.getLogger('MC_VGG_224')


os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"

lr = 0.1
nb_epoch = 300
criterion = nn.CrossEntropyLoss()

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),  # transforms.Scale is deprecated
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


trainset = torchvision.datasets.ImageFolder(root='/home/data/Birds/train', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=16, drop_last=True)

testset = torchvision.datasets.ImageFolder(root='/home/data/Birds/test', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=16)


print('==> Building model..')

# 'VGG16' is modified for the MC-Loss: the final conv of the last two stages
# is widened from 512 to 600 channels and the trailing 'M' pooling is
# dropped, so a 224x224 input yields a 600x14x14 feature map.
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 600, 'M', 512, 512, 600],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        # CIFAR-style head kept from the original template; model_bn below
        # only slices self.features, so this classifier is never used here.
        self.classifier = nn.Linear(512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)


def Mask(nb_batch, channels):
    # Draw one random binary pattern: within every group of 3 consecutive
    # channels, shuffle [1, 1, 0] so exactly two channels are kept.
    foo = [1] * 2 + [0] * 1
    bar = []
    for i in range(200):
        random.shuffle(foo)
        bar += foo
    bar = [bar for i in range(nb_batch)]
    bar = np.array(bar).astype("float32")
    bar = bar.reshape(nb_batch, 200 * channels, 1, 1)
    bar = torch.from_numpy(bar)
    bar = bar.cuda()
    return bar
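
# What Mask(nb_batch, 3) returns (an annotation of its behavior): a binary
# tensor of shape (nb_batch, 600, 1, 1) that broadcasts over the 14x14
# feature map. Note the random pattern is drawn once and shared by every
# sample in the batch (bar is replicated, not re-shuffled, per sample).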
def supervisor(x, targets, height, cnum):
    # Diversity branch (L_div): softmax over spatial positions, then
    # cross-channel max pooling over each group of cnum channels; the
    # closer the summed response is to cnum, the more the channels in a
    # group attend to different positions.
    mask = Mask(x.size(0), cnum)
    branch = x
    branch = branch.reshape(branch.size(0), branch.size(1), branch.size(2) * branch.size(3))
    branch = F.softmax(branch, 2)
    branch = branch.reshape(branch.size(0), branch.size(1), x.size(2), x.size(2))
    branch = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch)
    branch = branch.reshape(branch.size(0), branch.size(1), branch.size(2) * branch.size(3))
    loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(branch, 2)) / cnum

    # Discriminality branch (L_dis): randomly mask one channel per group,
    # max-pool each group to a single channel, average-pool spatially and
    # classify the resulting 200-d vector.
    branch_1 = x * mask

    branch_1 = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch_1)
    branch_1 = nn.AvgPool2d(kernel_size=(height, height))(branch_1)
    branch_1 = branch_1.view(branch_1.size(0), -1)

    loss_1 = criterion(branch_1, targets)

    return [loss_1, loss_2]


class model_bn(nn.Module):
    def __init__(self, feature_size=512, classes_num=200):

        super(model_bn, self).__init__()

        # Split the VGG16 trunk so the MC-Loss supervisor can tap the
        # 600x14x14 feature map produced by features_2.
        self.features_1 = nn.Sequential(*list(VGG('VGG16').features.children())[:34])
        self.features_2 = nn.Sequential(*list(VGG('VGG16').features.children())[34:])

        self.max = nn.MaxPool2d(kernel_size=2, stride=2)

        self.num_ftrs = 600 * 7 * 7
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs),
            # nn.Dropout(0.5),
            nn.Linear(self.num_ftrs, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            # nn.Dropout(0.5),
            nn.Linear(feature_size, classes_num),
        )

    def forward(self, x, targets):

        x = self.features_1(x)
        x = self.features_2(x)

        if self.training:
            MC_loss = supervisor(x, targets, height=14, cnum=3)

        x = self.max(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        loss = criterion(x, targets)

        if self.training:
            return x, loss, MC_loss
        else:
            return x, loss


use_cuda = torch.cuda.is_available()


net = model_bn(512, 200)

if use_cuda:
    net.classifier.cuda()
    net.features_1.cuda()
    net.features_2.cuda()

    net.classifier = torch.nn.DataParallel(net.classifier)
    net.features_1 = torch.nn.DataParallel(net.features_1)
    net.features_2 = torch.nn.DataParallel(net.features_2)

    cudnn.benchmark = True


def train(epoch, net, args, trainloader, optimizer):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    idx = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        idx = batch_idx

        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        out, ce_loss, MC_loss = net(inputs, targets)

        loss = ce_loss + args["alpha_1"] * MC_loss[0] + args["beta_1"] * MC_loss[1]

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        _, predicted = torch.max(out.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

    train_acc = 100. * correct / total
    train_loss = train_loss / (idx + 1)
    logging.info('Iteration %d, train_acc = %.5f, train_loss = %.6f' % (epoch, train_acc, train_loss))
    return train_acc, train_loss

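# Mapping to the README formula Loss = ce_loss + alpha_1*L_dis + beta_1*L_div:
# MC_loss[0] is L_dis (cross-entropy on the masked, channel-pooled features)
# and MC_loss[1] is L_div (the softmax/cross-channel-max-pooling term), so
# --alpha_1 and --beta_1 weight them directly.
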
def test(epoch, net, testloader, optimizer):

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    idx = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        with torch.no_grad():
            idx = batch_idx
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            out, ce_loss = net(inputs, targets)

            test_loss += ce_loss.item()
            _, predicted = torch.max(out.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().item()

    test_acc = 100. * correct / total
    test_loss = test_loss / (idx + 1)
    logging.info('test, test_acc = %.4f, test_loss = %.4f' % (test_acc, test_loss))

    return test_acc


def cosine_anneal_schedule(t):
    # Kept for reference but unused: training below uses the step schedule
    # (lr 0.1 -> 0.01 at epoch 150 -> 0.001 at epoch 225).
    cos_inner = np.pi * (t % nb_epoch)  # t - 1 is used when t has 1-based indexing.
    cos_inner /= nb_epoch
    cos_out = np.cos(cos_inner) + 1
    return float(0.1 / 2 * cos_out)


optimizer = optim.SGD([
    {'params': net.classifier.parameters(), 'lr': 0.1},
    {'params': net.features_1.parameters(), 'lr': 0.1},
    {'params': net.features_2.parameters(), 'lr': 0.1},
],
    momentum=0.9, weight_decay=5e-4)


def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MC2_AutoML Example')

    parser.add_argument('--alpha_1', type=float, default=1.5, metavar='ALPHA',
                        help='alpha_1 value (default: 1.5)')
    parser.add_argument('--beta_1', type=float, default=20.0, metavar='BETA',
                        help='beta_1 value (default: 20.0)')

    args, _ = parser.parse_known_args()
    return args


if __name__ == '__main__':
    try:
        args = vars(get_params())
        print(args)
        # main(params)
        max_val_acc = 0
        for epoch in range(1, nb_epoch + 1):
            # Step learning-rate schedule.
            if epoch == 150:
                lr = 0.01
            if epoch == 225:
                lr = 0.001
            optimizer.param_groups[0]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr
            optimizer.param_groups[2]['lr'] = lr

            train(epoch, net, args, trainloader, optimizer)
            test_acc = test(epoch, net, testloader, optimizer)
            if test_acc > max_val_acc:
                max_val_acc = test_acc

            print("max_val_acc", max_val_acc)

    except Exception as exception:
        logger.exception(exception)
        raise

--------------------------------------------------------------------------------
/CUB-200-2011_ResNet18.py:
--------------------------------------------------------------------------------
'''PyTorch CUB-200-2011 Training with ResNet18 (TRAINED FROM SCRATCH).

NOTE: for a fair baseline, keep the number of channels of the final
features the same as in the vanilla ResNet18; here layer4 is widened to
600 channels for the MC-Loss.'''
from __future__ import print_function
import os
# import nni
import logging
import argparse
import random
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from my_pooling import my_MaxPool2d


logger = logging.getLogger('MC_ResNet18_224')


os.environ["CUDA_VISIBLE_DEVICES"] = "0"

lr = 0.1
nb_epoch = 300
criterion = nn.CrossEntropyLoss()

# Data
print('==> Preparing data..')
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),  # transforms.Scale is deprecated
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


trainset = torchvision.datasets.ImageFolder(root='/home/data/Birds/train', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=16, drop_last=True)

testset = torchvision.datasets.ImageFolder(root='/home/data/Birds/test', transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=16)


print('==> Building model..')

# Model (adapted from torchvision's ResNet; only resnet18 is defined here,
# so __all__ and model_urls are trimmed accordingly).

import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    "3x3 convolution with padding"
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
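
# NOTE (annotation): resnet18 below is assembled from BasicBlock only; the
# Bottleneck class that follows is retained from the torchvision source for
# completeness and is never instantiated in this script.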
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        # layer4 is widened from 512 to 600 channels and kept at stride 1,
        # so a 224x224 input yields a 600x14x14 feature map for the MC-Loss.
        self.layer4 = self._make_layer(block, 600, layers[3], stride=1)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)
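
# Design note (annotation, not torchvision code): 600 = 200 classes x
# cnum(3) channels, matching Mask/supervisor below. The avgpool/fc head
# declared in __init__ is sliced off by model_bn (children()[:-2]) and the
# class defines no forward(), so the 512 * expansion fc in-features are
# never exercised.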
def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
            (note: incompatible with the widened layer4 above; this script
            trains from scratch and calls resnet18(pretrained=False)).
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    return model


net = resnet18(pretrained=False)


def Mask(nb_batch, channels):
    # Same mask as in CUB-200-2011.py: within every group of 3 consecutive
    # channels, shuffle [1, 1, 0] so exactly two channels are kept; the
    # pattern is drawn once and shared across the batch.
    foo = [1] * 2 + [0] * 1
    bar = []
    for i in range(200):
        random.shuffle(foo)
        bar += foo
    bar = [bar for i in range(nb_batch)]
    bar = np.array(bar).astype("float32")
    bar = bar.reshape(nb_batch, 200 * channels, 1, 1)
    bar = torch.from_numpy(bar)
    bar = bar.cuda()
    return bar


def supervisor(x, targets, height, cnum):
    # MC-Loss supervisor, identical to the VGG16 script: returns
    # [L_dis, L_div] computed on the 600 x height x height feature map.
    mask = Mask(x.size(0), cnum)
    branch = x
    branch = branch.reshape(branch.size(0), branch.size(1), branch.size(2) * branch.size(3))
    branch = F.softmax(branch, 2)
    branch = branch.reshape(branch.size(0), branch.size(1), x.size(2), x.size(2))
    branch = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch)
    branch = branch.reshape(branch.size(0), branch.size(1), branch.size(2) * branch.size(3))
    loss_2 = 1.0 - 1.0 * torch.mean(torch.sum(branch, 2)) / cnum

    branch_1 = x * mask

    branch_1 = my_MaxPool2d(kernel_size=(1, cnum), stride=(1, cnum))(branch_1)
    branch_1 = nn.AvgPool2d(kernel_size=(height, height))(branch_1)
    branch_1 = branch_1.view(branch_1.size(0), -1)

    loss_1 = criterion(branch_1, targets)

    return [loss_1, loss_2]


class model_bn(nn.Module):
    def __init__(self, feature_size=512, classes_num=200):

        super(model_bn, self).__init__()

        # Keep only the convolutional trunk; drop avgpool and fc.
        self.features = nn.Sequential(*list(net.children())[:-2])

        # Global max pooling over the 14x14 map -> (N, 600, 1, 1).
        self.max = nn.MaxPool2d(kernel_size=14, stride=14)

        self.num_ftrs = 600 * 1 * 1
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(self.num_ftrs),
            # nn.Dropout(0.5),
            nn.Linear(self.num_ftrs, feature_size),
            nn.BatchNorm1d(feature_size),
            nn.ELU(inplace=True),
            # nn.Dropout(0.5),
            nn.Linear(feature_size, classes_num),
        )

    def forward(self, x, targets):

        x = self.features(x)

        if self.training:
            MC_loss = supervisor(x, targets, height=14, cnum=3)

        x = self.max(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        loss = criterion(x, targets)

        if self.training:
            return x, loss, MC_loss
        else:
            return x, loss


use_cuda = torch.cuda.is_available()


net = model_bn(512, 200)

if use_cuda:
    net.classifier.cuda()
    net.features.cuda()

    net.classifier = torch.nn.DataParallel(net.classifier)
    net.features = torch.nn.DataParallel(net.features)

    cudnn.benchmark = True


def train(epoch, net, args, trainloader, optimizer):
    print('\nEpoch: %d' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    idx = 0

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        idx = batch_idx

        inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        out, ce_loss, MC_loss = net(inputs, targets)

        # Total objective: ce_loss + alpha_1 * L_dis + beta_1 * L_div.
        loss = ce_loss + args["alpha_1"] * MC_loss[0] + args["beta_1"] * MC_loss[1]
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        _, predicted = torch.max(out.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum().item()

    train_acc = 100. * correct / total
    train_loss = train_loss / (idx + 1)
    logging.info('Iteration %d, train_acc = %.5f, train_loss = %.6f' % (epoch, train_acc, train_loss))
    return train_acc, train_loss


def test(epoch, net, testloader, optimizer):

    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    idx = 0
    for batch_idx, (inputs, targets) in enumerate(testloader):
        with torch.no_grad():
            idx = batch_idx
            if use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()
            out, ce_loss = net(inputs, targets)

            test_loss += ce_loss.item()
            _, predicted = torch.max(out.data, 1)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().item()

    test_acc = 100. * correct / total
    test_loss = test_loss / (idx + 1)
    logging.info('test, test_acc = %.4f, test_loss = %.4f' % (test_acc, test_loss))

    return test_acc


def cosine_anneal_schedule(t):
    # Kept for reference but unused: training below uses the step schedule
    # (lr 0.1 -> 0.01 at epoch 150 -> 0.001 at epoch 225).
    cos_inner = np.pi * (t % nb_epoch)  # t - 1 is used when t has 1-based indexing.
    cos_inner /= nb_epoch
    cos_out = np.cos(cos_inner) + 1
    return float(0.1 / 2 * cos_out)


optimizer = optim.SGD([
    {'params': net.classifier.parameters(), 'lr': 0.1},
    {'params': net.features.parameters(), 'lr': 0.1},
],
    momentum=0.9, weight_decay=5e-4)


def get_params():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MC2_AutoML Example')

    parser.add_argument('--alpha_1', type=float, default=1.5, metavar='ALPHA',
                        help='alpha_1 value (default: 1.5)')
    parser.add_argument('--beta_1', type=float, default=20.0, metavar='BETA',
                        help='beta_1 value (default: 20.0)')

    args, _ = parser.parse_known_args()
    return args


if __name__ == '__main__':
    try:
        args = vars(get_params())
        print(args)
        # main(params)
        max_val_acc = 0
        for epoch in range(1, nb_epoch + 1):
            # Step learning-rate schedule.
            if epoch == 150:
                lr = 0.01
            if epoch == 225:
                lr = 0.001
            optimizer.param_groups[0]['lr'] = lr
            optimizer.param_groups[1]['lr'] = lr

            train(epoch, net, args, trainloader, optimizer)
            test_acc = test(epoch, net, testloader, optimizer)
            if test_acc > max_val_acc:
                max_val_acc = test_acc

            print("max_val_acc", max_val_acc)

    except Exception as exception:
        logger.exception(exception)
        raise

--------------------------------------------------------------------------------