├── .gitignore ├── LICENSE ├── NOTICE ├── README.md ├── img1.PNG ├── pyramidnet.py ├── resnet.py ├── test.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .idea/ 3 | runs/ 4 | train_nsml.py 5 | requirements.txt 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019-present NAVER Corp. 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | CutMix 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | This project contains subcomponents with separate copyright notices and license terms. 5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 6 | 7 | ======================================================================= 8 | pytorch/vision from https://github.com/pytorch/vision 9 | ======================================================================= 10 | 11 | BSD 3-Clause License 12 | 13 | Copyright (c) Soumith Chintala 2016, 14 | All rights reserved. 15 | 16 | Redistribution and use in source and binary forms, with or without 17 | modification, are permitted provided that the following conditions are met: 18 | 19 | * Redistributions of source code must retain the above copyright notice, this 20 | list of conditions and the following disclaimer. 21 | 22 | * Redistributions in binary form must reproduce the above copyright notice, 23 | this list of conditions and the following disclaimer in the documentation 24 | and/or other materials provided with the distribution. 25 | 26 | * Neither the name of the copyright holder nor the names of its 27 | contributors may be used to endorse or promote products derived from 28 | this software without specific prior written permission. 29 | 30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ======================================================================= 42 | dyhan0920/PyramidNet-PyTorch from https://github.com/dyhan0920/PyramidNet-PyTorch 43 | ======================================================================= 44 | 45 | MIT License 46 | 47 | Copyright (c) 2019 Dongyoon Han 48 | 49 | Permission is hereby granted, free of charge, to any person obtaining a copy 50 | of this software and associated documentation files (the "Software"), to deal 51 | in the Software without restriction, including without limitation the rights 52 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 53 | copies of the Software, and to permit persons to whom the Software is 54 | furnished to do so, subject to the following conditions: 55 | 56 | The above copyright notice and this permission notice shall be included in all 57 | copies or substantial portions of the Software. 58 | 59 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 60 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 61 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 62 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 63 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 64 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 65 | SOFTWARE. 66 | 67 | ======================================================================= 68 | eladhoffer/convNet.pytorch from https://github.com/eladhoffer/convNet.pytorch 69 | ======================================================================= 70 | 71 | MIT License 72 | 73 | Copyright (c) 2017 Elad Hoffer 74 | 75 | Permission is hereby granted, free of charge, to any person obtaining a copy 76 | of this software and associated documentation files (the "Software"), to deal 77 | in the Software without restriction, including without limitation the rights 78 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 79 | copies of the Software, and to permit persons to whom the Software is 80 | furnished to do so, subject to the following conditions: 81 | 82 | The above copyright notice and this permission notice shall be included in all 83 | copies or substantial portions of the Software. 84 | 85 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 86 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 87 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 88 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 89 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 90 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 91 | SOFTWARE. 
92 | 93 | ===== 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Accepted at ICCV 2019 (oral talk)!! 2 | 3 | ## CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features 4 | 5 | Official PyTorch implementation of the CutMix regularizer | [Paper](https://arxiv.org/abs/1905.04899) | [Pretrained Models](#experiments) 6 | 7 | **[Sangdoo Yun](mailto:sangdoo.yun@navercorp.com), Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.** 8 | 9 | Clova AI Research, NAVER Corp. 10 | 11 | Our implementation is based on these repositories: 12 | - [PyTorch ImageNet Example](https://github.com/pytorch/examples/tree/master/imagenet) 13 | - [PyramidNet-PyTorch](https://github.com/dyhan0920/PyramidNet-PyTorch) 14 | 15 | 16 | ### Abstract 17 | Regional dropout strategies have been proposed to enhance the performance of convolutional neural network classifiers. 18 | They have proved to be effective for guiding the model to attend to less discriminative parts of objects 19 | (e.g. the leg as opposed to the head of a person), thereby letting the network generalize better and have better object localization capabilities. 20 | On the other hand, current methods for regional dropout remove informative pixels from training images by overlaying a patch of either black pixels or random noise. 21 | Such removal is not desirable because it leads to information loss and inefficiency during training. 22 | We therefore propose the **CutMix** augmentation strategy: patches are cut and pasted among training images, and the ground truth labels are mixed proportionally to the area of the patches. 23 | By making efficient use of training pixels and retaining the regularization effect of regional dropout, CutMix consistently outperforms state-of-the-art augmentation strategies on CIFAR and ImageNet classification tasks, as well as on the ImageNet weakly-supervised localization task. 24 | Moreover, unlike previous augmentation methods, our CutMix-trained ImageNet classifier, when used as a pretrained model, yields consistent performance gains on Pascal VOC detection and MS-COCO image captioning benchmarks. 25 | We also show that CutMix improves the model's robustness against input corruptions and its out-of-distribution detection performance. 26 | 27 | 28 | ### Overview of the results of Mixup, Cutout, and CutMix 29 | 30 | ![teaser](img1.PNG) 31 | 32 | ## Updates 33 | **23 May, 2019**: Initial upload 34 | 35 | ## Getting Started 36 | ### Requirements 37 | - Python 3 38 | - PyTorch (> 1.0) 39 | - torchvision (> 0.2) 40 | - NumPy 41 | 42 | ### Train Examples 43 | - CIFAR-100: We used 2 GPUs to train on CIFAR-100. 44 | ``` 45 | python train.py \ 46 | --net_type pyramidnet \ 47 | --dataset cifar100 \ 48 | --depth 200 \ 49 | --alpha 240 \ 50 | --batch_size 64 \ 51 | --lr 0.25 \ 52 | --expname PyraNet200 \ 53 | --epochs 300 \ 54 | --beta 1.0 \ 55 | --cutmix_prob 0.5 \ 56 | --no-verbose 57 | ``` 58 | - ImageNet: We used 4 GPUs to train on ImageNet.
59 | ``` 60 | python train.py \ 61 | --net_type resnet \ 62 | --dataset imagenet \ 63 | --batch_size 256 \ 64 | --lr 0.1 \ 65 | --depth 50 \ 66 | --epochs 300 \ 67 | --expname ResNet50 \ 68 | -j 40 \ 69 | --beta 1.0 \ 70 | --cutmix_prob 1.0 \ 71 | --no-verbose 72 | ``` 73 | 74 | ### Test Examples Using Pretrained Models 75 | - Download [CutMix-pretrained PyramidNet200 (top-1 error: 14.23)](https://www.dropbox.com/sh/o68qbvayptt2rz5/AACy3o779BxoRqw6_GQf_QFQa?dl=0) 76 | ``` 77 | python test.py \ 78 | --net_type pyramidnet \ 79 | --dataset cifar100 \ 80 | --batch_size 64 \ 81 | --depth 200 \ 82 | --alpha 240 \ 83 | --pretrained /set/your/model/path/model_best.pth.tar 84 | ``` 85 | - Download [CutMix-pretrained ResNet50 (top-1 error: 21.40)](https://www.dropbox.com/sh/w8dvfgdc3eirivf/AABnGcTO9wao9xVGWwqsXRala?dl=0) 86 | ``` 87 | python test.py \ 88 | --net_type resnet \ 89 | --dataset imagenet \ 90 | --batch_size 64 \ 91 | --depth 50 \ 92 | --pretrained /set/your/model/path/model_best.pth.tar 93 | ``` 94 | 95 |
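### How CutMix Works

CutMix cuts a random rectangle out of each training image, fills it with the corresponding patch from another image in the batch, and mixes the two labels in proportion to the pixel area each source contributes. The snippet below is a condensed sketch of the CutMix step in `train.py` (the full loop additionally applies CutMix only with probability `--cutmix_prob`); `input`, `target`, `model`, `criterion`, and `args` are the same names used there.

```
import numpy as np
import torch

def rand_bbox(size, lam):
    # random box covering a (1 - lam) fraction of the image area, clipped to the borders
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

# inside the training loop: input is an (N, C, H, W) batch, target is (N,)
lam = np.random.beta(args.beta, args.beta)         # mixing ratio
rand_index = torch.randperm(input.size(0)).cuda()  # shuffled partner for each sample
target_a, target_b = target, target[rand_index]
bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
# recompute lambda to match the exact pixel ratio after clipping
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size(-1) * input.size(-2)))
output = model(input)
loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
```

Because `lam` is recomputed from the clipped box, the label weights always match the true pixel ratio even when the sampled box crosses the image border.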

### Experimental Results and Pretrained Models

96 | 97 | - PyramidNet-200 pretrained on CIFAR-100 dataset: 98 | 99 | Method | Top-1 Error | Model file 100 | -- | -- | -- 101 | PyramidNet-200 [[CVPR'17](https://arxiv.org/abs/1610.02915)] (baseline) | 16.45 | [model](https://www.dropbox.com/sh/6rfew3lr761jq6c/AADrdQOXNx5tWmgOSnAw9NEVa?dl=0) 102 | PyramidNet-200 + **CutMix** | **14.23** | [model](https://www.dropbox.com/sh/o68qbvayptt2rz5/AACy3o779BxoRqw6_GQf_QFQa?dl=0) 103 | PyramidNet-200 + Shakedrop [[arXiv'18](https://arxiv.org/abs/1802.02375)] + **CutMix** | **13.81** | - 104 | PyramidNet-200 + Mixup [[ICLR'18](https://arxiv.org/abs/1710.09412)] | 15.63 | [model](https://www.dropbox.com/sh/g55jnsv62v0n59s/AAC9LPg-LjlnBn4ttKs6vr7Ka?dl=0) 105 | PyramidNet-200 + Manifold Mixup [[ICML'19](https://arxiv.org/abs/1806.05236)] | 16.14 | [model](https://www.dropbox.com/sh/nngw7hhk1e8msbr/AABkdCsP0ABnQJDBX7LQVj4la?dl=0) 106 | PyramidNet-200 + Cutout [[arXiv'17](https://arxiv.org/abs/1708.04552)] | 16.53 | [model](https://www.dropbox.com/sh/ajjz4q8c8t6qva9/AAAeBGb2Q4TnJMW0JAzeVSpfa?dl=0) 107 | PyramidNet-200 + DropBlock [[NeurIPS'18](https://arxiv.org/abs/1810.12890)] | 15.73 | [model](https://www.dropbox.com/sh/vefjo960gyrsx2i/AACYA5wOJ_yroNjIjdsN1Dz2a?dl=0) 108 | PyramidNet-200 + Cutout + Labelsmoothing | 15.61 | [model](https://www.dropbox.com/sh/1mur0kjcfxdn7jn/AADmghqrj0dXAG0qY1v3Csb6a?dl=0) 109 | PyramidNet-200 + DropBlock + Labelsmoothing | 15.16 | [model](https://www.dropbox.com/sh/n1dn6ggyxjcoogc/AADpSSNzvaraSCqWtHBE0qMca?dl=0) 110 | PyramidNet-200 + Cutout + Mixup | 15.46 | [model](https://www.dropbox.com/sh/5run1sx8oy0v9oi/AACiR_wEBQVp2HMZFx6lGl3ka?dl=0) 111 | 112 | 113 | - ResNet models pretrained on ImageNet dataset: 114 | 115 | Method | Top-1 Error | Model file 116 | -- | -- | -- 117 | ResNet-50 [[CVPR'16](https://arxiv.org/abs/1512.03385)] (baseline) | 23.68 | [model](https://www.dropbox.com/sh/phwbbrtadrclpnx/AAA9QUW9G_xvBdI-mDiIzP_Ha?dl=0) 118 | ResNet-50 + **CutMix** | **21.40** | [model](https://www.dropbox.com/sh/w8dvfgdc3eirivf/AABnGcTO9wao9xVGWwqsXRala?dl=0) 119 | ResNet-50 + **Feature CutMix** | **21.80** | [model](https://www.dropbox.com/sh/zj1wptsg0hwqf0k/AABRNzvjFmIS7_vOEQkqb6T4a?dl=0) 120 | ResNet-50 + Mixup [[ICLR'18](https://arxiv.org/abs/1710.09412)] | 22.58 | [model](https://www.dropbox.com/sh/g64c8bda61n12if/AACyaTZnku_Sgibc9UvOSblNa?dl=0) 121 | ResNet-50 + Manifold Mixup [[ICML'19](https://arxiv.org/abs/1806.05236)] | 22.50 | [model](https://www.dropbox.com/sh/bjardjje11pti0g/AABFGW0gNrNE8o8TqUf4-SYSa?dl=0) 122 | ResNet-50 + Cutout [[arXiv'17](https://arxiv.org/abs/1708.04552)] | 22.93 | [model](https://www.dropbox.com/sh/ln8zk2z7zt2h1en/AAA7z8xTBlzz7Ofbd5L7oTnTa?dl=0) 123 | ResNet-50 + AutoAugment [[CVPR'19](https://arxiv.org/abs/1805.09501)] | 22.40* | - 124 | ResNet-50 + DropBlock [[NeurIPS'18](https://arxiv.org/abs/1810.12890)] | 21.87* | - 125 | ResNet-101 + **CutMix** | **20.17** | [model](https://www.dropbox.com/sh/1z4xnp9nwdmpzb5/AACQX4KU8XkTN0JSTfjkCktNa?dl=0) 126 | ResNet-152 + **CutMix** | **19.20** | [model](https://www.dropbox.com/s/6vq1mzy27z8qxko/resnet152_cutmix_acc_80_80.pth?dl=0) 127 | ResNeXt-101 (32x4d) + **CutMix** | **19.47** | [model](https://www.dropbox.com/s/maysvgopsi17qi0/resnext_cutmix.pth.tar?dl=0) 128 | 129 | \* denotes results reported in the original papers 130 | 131 | ## Transfer Learning Results 132 | 133 | Backbone | ImageNet Cls (%) | ImageNet Loc (%) | CUB200 Loc (%) | Detection (SSD) (mAP) | Detection (Faster-RCNN) (mAP) | Image Captioning (BLEU-4) 134 | -- | -- | -- | -- | 
-- | -- | -- 135 | ResNet50 | 23.68 | 46.3 | 49.41 | 76.7 | 75.6 | 22.9 136 | ResNet50+Mixup | 22.58 | 45.84 | 49.3 | 76.6 | 73.9 | 23.2 137 | ResNet50+Cutout | 22.93 | 46.69 | 52.78 | 76.8 | 75 | 24.0 138 | ResNet50+**CutMix** | **21.60** | **46.25** | **54.81** | **77.6** | **76.7** | **24.9** 139 | 140 | 141 | ## Third-party Implementations 142 | - [Pytorch-CutMix](https://github.com/hysts/pytorch_cutmix) by @hysts 143 | - [TensorFlow-CutMix](https://github.com/jis478/Tensorflow/tree/master/TF2.0/Cutmix) by @jis478 144 | 145 | ## Citation 146 | ``` 147 | @inproceedings{yun2019cutmix, 148 | title={CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features}, 149 | author={Yun, Sangdoo and Han, Dongyoon and Oh, Seong Joon and Chun, Sanghyuk and Choe, Junsuk and Yoo, Youngjoon}, 150 | booktitle = {International Conference on Computer Vision (ICCV)}, 151 | year={2019}, 152 | pubstate={published}, 153 | tppubtype={inproceedings} 154 | } 155 | ``` 156 | 157 | ## License 158 | ``` 159 | Copyright (c) 2019-present NAVER Corp. 160 | 161 | Permission is hereby granted, free of charge, to any person obtaining a copy 162 | of this software and associated documentation files (the "Software"), to deal 163 | in the Software without restriction, including without limitation the rights 164 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 165 | copies of the Software, and to permit persons to whom the Software is 166 | furnished to do so, subject to the following conditions: 167 | 168 | The above copyright notice and this permission notice shall be included in 169 | all copies or substantial portions of the Software. 170 | 171 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 172 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 173 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 174 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 175 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 176 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 177 | THE SOFTWARE. 
178 | ``` 179 | -------------------------------------------------------------------------------- /img1.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/CutMix-PyTorch/2d8eb68faff7fe4962776ad51d175c3b01a25734/img1.PNG -------------------------------------------------------------------------------- /pyramidnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/PyramidNet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | "3x3 convolution with padding" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | outchannel_ratio = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.bn1 = nn.BatchNorm2d(inplanes) 19 | self.conv1 = conv3x3(inplanes, planes, stride) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn3 = nn.BatchNorm2d(planes) 23 | self.relu = nn.ReLU(inplace=True) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | 29 | out = self.bn1(x) 30 | out = self.conv1(out) 31 | out = self.bn2(out) 32 | out = self.relu(out) 33 | out = self.conv2(out) 34 | out = self.bn3(out) 35 | if self.downsample is not None: 36 | shortcut = self.downsample(x) 37 | featuremap_size = shortcut.size()[2:4] 38 | else: 39 | shortcut = x 40 | featuremap_size = out.size()[2:4] 41 | 42 | batch_size = out.size()[0] 43 | residual_channel = out.size()[1] 44 | shortcut_channel = shortcut.size()[1] 45 | 46 | if residual_channel != shortcut_channel: 47 | padding = torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 48 | out += torch.cat((shortcut, padding), 1) 49 | else: 50 | out += shortcut 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | outchannel_ratio = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 59 | super(Bottleneck, self).__init__() 60 | self.bn1 = nn.BatchNorm2d(inplanes) 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv2 = nn.Conv2d(planes, (planes), kernel_size=3, stride=stride, padding=1, bias=False, groups=1) 64 | self.bn3 = nn.BatchNorm2d((planes)) 65 | self.conv3 = nn.Conv2d((planes), planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False) 66 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | 74 | out = self.bn1(x) 75 | out = self.conv1(out) 76 | 77 | out = self.bn2(out) 78 | out = self.relu(out) 79 | out = self.conv2(out) 80 | 81 | out = self.bn3(out) 82 | out = self.relu(out) 83 | out = self.conv3(out) 84 | 85 | out = self.bn4(out) 86 | if self.downsample is not None: 87 | shortcut = self.downsample(x) 88 | featuremap_size = shortcut.size()[2:4] 89 | else: 90 | shortcut = x 91 | featuremap_size = out.size()[2:4] 92 | 93 | batch_size = out.size()[0] 94 | residual_channel = out.size()[1] 95 | shortcut_channel = shortcut.size()[1] 96 | 97 | if residual_channel != shortcut_channel: 98 | padding = 
torch.autograd.Variable(torch.cuda.FloatTensor(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1]).fill_(0)) 99 | out += torch.cat((shortcut, padding), 1) 100 | else: 101 | out += shortcut 102 | 103 | return out 104 | 105 | 106 | class PyramidNet(nn.Module): 107 | 108 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False): 109 | super(PyramidNet, self).__init__() 110 | self.dataset = dataset 111 | if self.dataset.startswith('cifar'): 112 | self.inplanes = 16 113 | if bottleneck == True: 114 | n = int((depth - 2) / 9) 115 | block = Bottleneck 116 | else: 117 | n = int((depth - 2) / 6) 118 | block = BasicBlock 119 | 120 | self.addrate = alpha / (3*n*1.0) 121 | 122 | self.input_featuremap_dim = self.inplanes 123 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False) 124 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 125 | 126 | self.featuremap_dim = self.input_featuremap_dim 127 | self.layer1 = self.pyramidal_make_layer(block, n) 128 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2) 129 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2) 130 | 131 | self.final_featuremap_dim = self.input_featuremap_dim 132 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 133 | self.relu_final = nn.ReLU(inplace=True) 134 | self.avgpool = nn.AvgPool2d(8) 135 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 136 | 137 | elif dataset == 'imagenet': 138 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 139 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 140 | 141 | if layers.get(depth) is None: 142 | if bottleneck == True: 143 | blocks[depth] = Bottleneck 144 | temp_cfg = int((depth-2)/12) 145 | else: 146 | blocks[depth] = BasicBlock 147 | temp_cfg = int((depth-2)/8) 148 | 149 | layers[depth]= [temp_cfg, temp_cfg, temp_cfg, temp_cfg] 150 | print('=> the layer configuration for each stage is set to', layers[depth]) 151 | 152 | self.inplanes = 64 153 | self.addrate = alpha / (sum(layers[depth])*1.0) 154 | 155 | self.input_featuremap_dim = self.inplanes 156 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False) 157 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim) 158 | self.relu = nn.ReLU(inplace=True) 159 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 160 | 161 | self.featuremap_dim = self.input_featuremap_dim 162 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0]) 163 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2) 164 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2) 165 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2) 166 | 167 | self.final_featuremap_dim = self.input_featuremap_dim 168 | self.bn_final= nn.BatchNorm2d(self.final_featuremap_dim) 169 | self.relu_final = nn.ReLU(inplace=True) 170 | self.avgpool = nn.AvgPool2d(7) 171 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes) 172 | 173 | for m in self.modules(): 174 | if isinstance(m, nn.Conv2d): 175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 176 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 177 | elif isinstance(m, nn.BatchNorm2d): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | def pyramidal_make_layer(self, block, block_depth, stride=1): 182 | downsample = None 183 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio: 184 | downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True) 185 | 186 | layers = [] 187 | self.featuremap_dim = self.featuremap_dim + self.addrate 188 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample)) 189 | for i in range(1, block_depth): 190 | temp_featuremap_dim = self.featuremap_dim + self.addrate 191 | layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1)) 192 | self.featuremap_dim = temp_featuremap_dim 193 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio 194 | 195 | return nn.Sequential(*layers) 196 | 197 | def forward(self, x): 198 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 199 | x = self.conv1(x) 200 | x = self.bn1(x) 201 | 202 | x = self.layer1(x) 203 | x = self.layer2(x) 204 | x = self.layer3(x) 205 | 206 | x = self.bn_final(x) 207 | x = self.relu_final(x) 208 | x = self.avgpool(x) 209 | x = x.view(x.size(0), -1) 210 | x = self.fc(x) 211 | 212 | elif self.dataset == 'imagenet': 213 | x = self.conv1(x) 214 | x = self.bn1(x) 215 | x = self.relu(x) 216 | x = self.maxpool(x) 217 | 218 | x = self.layer1(x) 219 | x = self.layer2(x) 220 | x = self.layer3(x) 221 | x = self.layer4(x) 222 | 223 | x = self.bn_final(x) 224 | x = self.relu_final(x) 225 | x = self.avgpool(x) 226 | x = x.view(x.size(0), -1) 227 | x = self.fc(x) 228 | 229 | return x 230 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | 3 | import torch.nn as nn 4 | import math 5 | 6 | def conv3x3(in_planes, out_planes, stride=1): 7 | "3x3 convolution with padding" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 9 | padding=1, bias=False) 10 | 11 | 12 | class BasicBlock(nn.Module): 13 | expansion = 1 14 | 15 | def __init__(self, inplanes, planes, stride=1, downsample=None): 16 | super(BasicBlock, self).__init__() 17 | self.conv1 = conv3x3(inplanes, planes, stride) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | 51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, 
kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion) 57 | self.relu = nn.ReLU(inplace=True) 58 | 59 | self.downsample = downsample 60 | self.stride = stride 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.bn1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.bn2(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv3(out) 74 | out = self.bn3(out) 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | class ResNet(nn.Module): 84 | def __init__(self, dataset, depth, num_classes, bottleneck=False): 85 | super(ResNet, self).__init__() 86 | self.dataset = dataset 87 | if self.dataset.startswith('cifar'): 88 | self.inplanes = 16 89 | print(bottleneck) 90 | if bottleneck == True: 91 | n = int((depth - 2) / 9) 92 | block = Bottleneck 93 | else: 94 | n = int((depth - 2) / 6) 95 | block = BasicBlock 96 | 97 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 98 | self.bn1 = nn.BatchNorm2d(self.inplanes) 99 | self.relu = nn.ReLU(inplace=True) 100 | self.layer1 = self._make_layer(block, 16, n) 101 | self.layer2 = self._make_layer(block, 32, n, stride=2) 102 | self.layer3 = self._make_layer(block, 64, n, stride=2) 103 | self.avgpool = nn.AvgPool2d(8) 104 | self.fc = nn.Linear(64 * block.expansion, num_classes) 105 | 106 | elif dataset == 'imagenet': 107 | blocks ={18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck} 108 | layers ={18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]} 109 | assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)' 110 | 111 | self.inplanes = 64 112 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 113 | self.bn1 = nn.BatchNorm2d(64) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 116 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0]) 117 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2) 118 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2) 119 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2) 120 | self.avgpool = nn.AvgPool2d(7) 121 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes) 122 | 123 | for m in self.modules(): 124 | if isinstance(m, nn.Conv2d): 125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 126 | m.weight.data.normal_(0, math.sqrt(2.
/ n)) 127 | elif isinstance(m, nn.BatchNorm2d): 128 | m.weight.data.fill_(1) 129 | m.bias.data.zero_() 130 | 131 | def _make_layer(self, block, planes, blocks, stride=1): 132 | downsample = None 133 | if stride != 1 or self.inplanes != planes * block.expansion: 134 | downsample = nn.Sequential( 135 | nn.Conv2d(self.inplanes, planes * block.expansion, 136 | kernel_size=1, stride=stride, bias=False), 137 | nn.BatchNorm2d(planes * block.expansion), 138 | ) 139 | 140 | layers = [] 141 | layers.append(block(self.inplanes, planes, stride, downsample)) 142 | self.inplanes = planes * block.expansion 143 | for i in range(1, blocks): 144 | layers.append(block(self.inplanes, planes)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | if self.dataset == 'cifar10' or self.dataset == 'cifar100': 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.relu(x) 153 | 154 | x = self.layer1(x) 155 | x = self.layer2(x) 156 | x = self.layer3(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | elif self.dataset == 'imagenet': 163 | x = self.conv1(x) 164 | x = self.bn1(x) 165 | x = self.relu(x) 166 | x = self.maxpool(x) 167 | 168 | x = self.layer1(x) 169 | x = self.layer2(x) 170 | x = self.layer3(x) 171 | x = self.layer4(x) 172 | 173 | x = self.avgpool(x) 174 | x = x.view(x.size(0), -1) 175 | x = self.fc(x) 176 | 177 | return x 178 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import resnet as RN 19 | import pyramidnet as PYRM 20 | 21 | import warnings 22 | 23 | warnings.filterwarnings("ignore") 24 | 25 | model_names = sorted(name for name in models.__dict__ 26 | if name.islower() and not name.startswith("__") 27 | and callable(models.__dict__[name])) 28 | 29 | parser = argparse.ArgumentParser(description='CutMix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Test') 30 | parser.add_argument('--net_type', default='pyramidnet', type=str, 31 | help='network type: resnet or pyramidnet') 32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 33 | help='number of data loading workers (default: 4)') 34 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 35 | help='number of total epochs to run') 36 | parser.add_argument('-b', '--batch_size', default=128, type=int, 37 | metavar='N', help='mini-batch size (default: 128)') 38 | parser.add_argument('--print-freq', '-p', default=1, type=int, 39 | metavar='N', help='print frequency (default: 1)') 40 | parser.add_argument('--depth', default=32, type=int, 41 | help='depth of the network (default: 32)') 42 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false', 43 | help='to use basicblock for CIFAR datasets (default: bottleneck)') 44 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str, 45 | help='dataset (options: cifar10, cifar100, and imagenet)') 46 |
parser.add_argument('--alpha', default=300, type=float, 47 | help='number of new channel increases per depth (default: 300)') 48 | parser.add_argument('--no-verbose', dest='verbose', action='store_false', 49 | help='disable printing the status at every iteration') 50 | parser.add_argument('--pretrained', default='/set/your/model/path', type=str, metavar='PATH') 51 | 52 | parser.set_defaults(bottleneck=True) 53 | parser.set_defaults(verbose=True) 54 | 55 | best_err1 = 100 56 | best_err5 = 100 57 | 58 | 59 | def main(): 60 | global args, best_err1, best_err5 61 | args = parser.parse_args() 62 | 63 | if args.dataset.startswith('cifar'): 64 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 65 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 66 | 67 | transform_train = transforms.Compose([ 68 | transforms.RandomCrop(32, padding=4), 69 | transforms.RandomHorizontalFlip(), 70 | transforms.ToTensor(), 71 | normalize, 72 | ]) 73 | 74 | transform_test = transforms.Compose([ 75 | transforms.ToTensor(), 76 | normalize 77 | ]) 78 | 79 | if args.dataset == 'cifar100': 80 | val_loader = torch.utils.data.DataLoader( 81 | datasets.CIFAR100('../data', train=False, transform=transform_test), 82 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 83 | numberofclass = 100 84 | elif args.dataset == 'cifar10': 85 | val_loader = torch.utils.data.DataLoader( 86 | datasets.CIFAR10('../data', train=False, transform=transform_test), 87 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 88 | numberofclass = 10 89 | else: 90 | raise Exception('unknown dataset: {}'.format(args.dataset)) 91 | 92 | elif args.dataset == 'imagenet': 93 | 94 | valdir = os.path.join('/home/data/ILSVRC/val') 95 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 96 | std=[0.229, 0.224, 0.225]) 97 | 98 | val_loader = torch.utils.data.DataLoader( 99 | datasets.ImageFolder(valdir, transforms.Compose([ 100 | transforms.Resize(256), 101 | transforms.CenterCrop(224), 102 | transforms.ToTensor(), 103 | normalize, 104 | ])), 105 | batch_size=args.batch_size, shuffle=False, 106 | num_workers=args.workers, pin_memory=True) 107 | numberofclass = 1000 108 | 109 | else: 110 | raise Exception('unknown dataset: {}'.format(args.dataset)) 111 | 112 | print("=> creating model '{}'".format(args.net_type)) 113 | if args.net_type == 'resnet': 114 | model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck) # for ResNet 115 | elif args.net_type == 'pyramidnet': 116 | model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, 117 | args.bottleneck) 118 | else: 119 | raise Exception('unknown network architecture: {}'.format(args.net_type)) 120 | 121 | model = torch.nn.DataParallel(model).cuda() 122 | 123 | if os.path.isfile(args.pretrained): 124 | print("=> loading checkpoint '{}'".format(args.pretrained)) 125 | checkpoint = torch.load(args.pretrained) 126 | model.load_state_dict(checkpoint['state_dict']) 127 | print("=> loaded checkpoint '{}'".format(args.pretrained)) 128 | else: 129 | raise Exception("=> no checkpoint found at '{}'".format(args.pretrained)) 130 | 131 | print(model) 132 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 133 | 134 | # define loss function (criterion) 135 | criterion = nn.CrossEntropyLoss().cuda() 136 | 137 | cudnn.benchmark = True 138 | 139 | # evaluate on validation set 140 | err1, err5, val_loss = validate(val_loader, model,
criterion) 141 | 142 | print('Test results (top-1 and top-5 error):', err1, err5) 143 | 144 | 145 | def validate(val_loader, model, criterion): 146 | batch_time = AverageMeter() 147 | losses = AverageMeter() 148 | top1 = AverageMeter() 149 | top5 = AverageMeter() 150 | 151 | # switch to evaluate mode 152 | model.eval() 153 | 154 | end = time.time() 155 | for i, (input, target) in enumerate(val_loader): 156 | target = target.cuda() 157 | 158 | output = model(input) 159 | loss = criterion(output, target) 160 | 161 | # measure accuracy and record loss 162 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 163 | 164 | losses.update(loss.item(), input.size(0)) 165 | 166 | top1.update(err1.item(), input.size(0)) 167 | top5.update(err5.item(), input.size(0)) 168 | 169 | # measure elapsed time 170 | batch_time.update(time.time() - end) 171 | end = time.time() 172 | 173 | if i % args.print_freq == 0 and args.verbose == True: 174 | print('Test (on val set): [{0}/{1}]\t' 175 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 176 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 177 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 178 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 179 | i, len(val_loader), batch_time=batch_time, loss=losses, 180 | top1=top1, top5=top5)) 181 | 182 | return top1.avg, top5.avg, losses.avg 183 | 184 | 185 | class AverageMeter(object): 186 | """Computes and stores the average and current value""" 187 | 188 | def __init__(self): 189 | self.reset() 190 | 191 | def reset(self): 192 | self.val = 0 193 | self.avg = 0 194 | self.sum = 0 195 | self.count = 0 196 | 197 | def update(self, val, n=1): 198 | self.val = val 199 | self.sum += val * n 200 | self.count += n 201 | self.avg = self.sum / self.count 202 | 203 | 204 | def accuracy(output, target, topk=(1,)): 205 | """Computes the top-k error (100 - precision@k) for the specified values of k""" 206 | maxk = max(topk) 207 | batch_size = target.size(0) 208 | 209 | _, pred = output.topk(maxk, 1, True, True) 210 | pred = pred.t() 211 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 212 | 213 | res = [] 214 | for k in topk: 215 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 216 | wrong_k = batch_size - correct_k 217 | res.append(wrong_k.mul_(100.0 / batch_size)) 218 | 219 | return res 220 | 221 | 222 | if __name__ == '__main__': 223 | main() 224 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py 2 | 3 | import argparse 4 | import os 5 | import shutil 6 | import time 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim 13 | import torch.utils.data 14 | import torch.utils.data.distributed 15 | import torchvision.transforms as transforms 16 | import torchvision.datasets as datasets 17 | import torchvision.models as models 18 | import resnet as RN 19 | import pyramidnet as PYRM 20 | import utils 21 | import numpy as np 22 | 23 | import warnings 24 | 25 | warnings.filterwarnings("ignore") 26 | 27 | model_names = sorted(name for name in models.__dict__ 28 | if name.islower() and not name.startswith("__") 29 | and callable(models.__dict__[name])) 30 | 31 | parser = argparse.ArgumentParser(description='CutMix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Training') 32 | parser.add_argument('--net_type', default='pyramidnet', type=str, 33
| help='network type: resnet or pyramidnet') 34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 37 | help='number of total epochs to run') 38 | parser.add_argument('-b', '--batch_size', default=128, type=int, 39 | metavar='N', help='mini-batch size (default: 128)') 40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 41 | metavar='LR', help='initial learning rate') 42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 43 | help='momentum') 44 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 45 | metavar='W', help='weight decay (default: 1e-4)') 46 | parser.add_argument('--print-freq', '-p', default=1, type=int, 47 | metavar='N', help='print frequency (default: 1)') 48 | parser.add_argument('--depth', default=32, type=int, 49 | help='depth of the network (default: 32)') 50 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false', 51 | help='to use basicblock for CIFAR datasets (default: bottleneck)') 52 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str, 53 | help='dataset (options: cifar10, cifar100, and imagenet)') 54 | parser.add_argument('--no-verbose', dest='verbose', action='store_false', 55 | help='disable printing the status at every iteration') 56 | parser.add_argument('--alpha', default=300, type=float, 57 | help='number of new channel increases per depth (default: 300)') 58 | parser.add_argument('--expname', default='TEST', type=str, 59 | help='name of experiment') 60 | parser.add_argument('--beta', default=0, type=float, 61 | help='hyperparameter beta') 62 | parser.add_argument('--cutmix_prob', default=0, type=float, 63 | help='cutmix probability') 64 | 65 | parser.set_defaults(bottleneck=True) 66 | parser.set_defaults(verbose=True) 67 | 68 | best_err1 = 100 69 | best_err5 = 100 70 | 71 | 72 | def main(): 73 | global args, best_err1, best_err5 74 | args = parser.parse_args() 75 | 76 | if args.dataset.startswith('cifar'): 77 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 78 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 79 | 80 | transform_train = transforms.Compose([ 81 | transforms.RandomCrop(32, padding=4), 82 | transforms.RandomHorizontalFlip(), 83 | transforms.ToTensor(), 84 | normalize, 85 | ]) 86 | 87 | transform_test = transforms.Compose([ 88 | transforms.ToTensor(), 89 | normalize 90 | ]) 91 | 92 | if args.dataset == 'cifar100': 93 | train_loader = torch.utils.data.DataLoader( 94 | datasets.CIFAR100('../data', train=True, download=True, transform=transform_train), 95 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 96 | val_loader = torch.utils.data.DataLoader( 97 | datasets.CIFAR100('../data', train=False, transform=transform_test), 98 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 99 | numberofclass = 100 100 | elif args.dataset == 'cifar10': 101 | train_loader = torch.utils.data.DataLoader( 102 | datasets.CIFAR10('../data', train=True, download=True, transform=transform_train), 103 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 104 | val_loader = torch.utils.data.DataLoader( 105 | datasets.CIFAR10('../data', train=False, transform=transform_test), 106 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True) 107 |
numberofclass = 10 108 | else: 109 | raise Exception('unknown dataset: {}'.format(args.dataset)) 110 | 111 | elif args.dataset == 'imagenet': 112 | traindir = os.path.join('/home/data/ILSVRC/train') 113 | valdir = os.path.join('/home/data/ILSVRC/val') 114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 115 | std=[0.229, 0.224, 0.225]) 116 | 117 | jittering = utils.ColorJitter(brightness=0.4, contrast=0.4, 118 | saturation=0.4) 119 | lighting = utils.Lighting(alphastd=0.1, 120 | eigval=[0.2175, 0.0188, 0.0045], 121 | eigvec=[[-0.5675, 0.7192, 0.4009], 122 | [-0.5808, -0.0045, -0.8140], 123 | [-0.5836, -0.6948, 0.4203]]) 124 | 125 | train_dataset = datasets.ImageFolder( 126 | traindir, 127 | transforms.Compose([ 128 | transforms.RandomResizedCrop(224), 129 | transforms.RandomHorizontalFlip(), 130 | transforms.ToTensor(), 131 | jittering, 132 | lighting, 133 | normalize, 134 | ])) 135 | 136 | train_sampler = None 137 | 138 | train_loader = torch.utils.data.DataLoader( 139 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 140 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 141 | 142 | val_loader = torch.utils.data.DataLoader( 143 | datasets.ImageFolder(valdir, transforms.Compose([ 144 | transforms.Resize(256), 145 | transforms.CenterCrop(224), 146 | transforms.ToTensor(), 147 | normalize, 148 | ])), 149 | batch_size=args.batch_size, shuffle=False, 150 | num_workers=args.workers, pin_memory=True) 151 | numberofclass = 1000 152 | 153 | else: 154 | raise Exception('unknown dataset: {}'.format(args.dataset)) 155 | 156 | print("=> creating model '{}'".format(args.net_type)) 157 | if args.net_type == 'resnet': 158 | model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck) # for ResNet 159 | elif args.net_type == 'pyramidnet': 160 | model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass, 161 | args.bottleneck) 162 | else: 163 | raise Exception('unknown network architecture: {}'.format(args.net_type)) 164 | 165 | model = torch.nn.DataParallel(model).cuda() 166 | 167 | print(model) 168 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 169 | 170 | # define loss function (criterion) and optimizer 171 | criterion = nn.CrossEntropyLoss().cuda() 172 | 173 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 174 | momentum=args.momentum, 175 | weight_decay=args.weight_decay, nesterov=True) 176 | 177 | cudnn.benchmark = True 178 | 179 | for epoch in range(0, args.epochs): 180 | 181 | adjust_learning_rate(optimizer, epoch) 182 | 183 | # train for one epoch 184 | train_loss = train(train_loader, model, criterion, optimizer, epoch) 185 | 186 | # evaluate on validation set 187 | err1, err5, val_loss = validate(val_loader, model, criterion, epoch) 188 | 189 | # remember best prec@1 and save checkpoint 190 | is_best = err1 <= best_err1 191 | best_err1 = min(err1, best_err1) 192 | if is_best: 193 | best_err5 = err5 194 | 195 | print('Current best accuracy (top-1 and 5 error):', best_err1, best_err5) 196 | save_checkpoint({ 197 | 'epoch': epoch, 198 | 'arch': args.net_type, 199 | 'state_dict': model.state_dict(), 200 | 'best_err1': best_err1, 201 | 'best_err5': best_err5, 202 | 'optimizer': optimizer.state_dict(), 203 | }, is_best) 204 | 205 | print('Best accuracy (top-1 and 5 error):', best_err1, best_err5) 206 | 207 | 208 | def train(train_loader, model, criterion, optimizer, epoch): 209 | batch_time = AverageMeter() 210 | data_time = AverageMeter() 
211 | losses = AverageMeter() 212 | top1 = AverageMeter() 213 | top5 = AverageMeter() 214 | 215 | # switch to train mode 216 | model.train() 217 | 218 | end = time.time() 219 | current_LR = get_learning_rate(optimizer)[0] 220 | for i, (input, target) in enumerate(train_loader): 221 | # measure data loading time 222 | data_time.update(time.time() - end) 223 | 224 | input = input.cuda() 225 | target = target.cuda() 226 | 227 | r = np.random.rand(1) 228 | if args.beta > 0 and r < args.cutmix_prob: 229 | # generate mixed sample 230 | lam = np.random.beta(args.beta, args.beta) 231 | rand_index = torch.randperm(input.size()[0]).cuda() 232 | target_a = target 233 | target_b = target[rand_index] 234 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam) 235 | input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2] 236 | # adjust lambda to exactly match pixel ratio 237 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2])) 238 | # compute output 239 | output = model(input) 240 | loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam) 241 | else: 242 | # compute output 243 | output = model(input) 244 | loss = criterion(output, target) 245 | 246 | # measure accuracy and record loss 247 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 248 | 249 | losses.update(loss.item(), input.size(0)) 250 | top1.update(err1.item(), input.size(0)) 251 | top5.update(err5.item(), input.size(0)) 252 | 253 | # compute gradient and do SGD step 254 | optimizer.zero_grad() 255 | loss.backward() 256 | optimizer.step() 257 | 258 | # measure elapsed time 259 | batch_time.update(time.time() - end) 260 | end = time.time() 261 | 262 | if i % args.print_freq == 0 and args.verbose == True: 263 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 264 | 'LR: {LR:.6f}\t' 265 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 266 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 267 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 268 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 269 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 270 | epoch, args.epochs, i, len(train_loader), LR=current_LR, batch_time=batch_time, 271 | data_time=data_time, loss=losses, top1=top1, top5=top5)) 272 | 273 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Train Loss {loss.avg:.3f}'.format( 274 | epoch, args.epochs, top1=top1, top5=top5, loss=losses)) 275 | 276 | return losses.avg 277 | 278 | 279 | def rand_bbox(size, lam): 280 | W = size[2] 281 | H = size[3] 282 | cut_rat = np.sqrt(1. 
- lam) 283 | cut_w = int(W * cut_rat) 284 | cut_h = int(H * cut_rat) 285 | 286 | # uniform 287 | cx = np.random.randint(W) 288 | cy = np.random.randint(H) 289 | 290 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 291 | bby1 = np.clip(cy - cut_h // 2, 0, H) 292 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 293 | bby2 = np.clip(cy + cut_h // 2, 0, H) 294 | 295 | return bbx1, bby1, bbx2, bby2 296 | 297 | 298 | def validate(val_loader, model, criterion, epoch): 299 | batch_time = AverageMeter() 300 | losses = AverageMeter() 301 | top1 = AverageMeter() 302 | top5 = AverageMeter() 303 | 304 | # switch to evaluate mode 305 | model.eval() 306 | 307 | end = time.time() 308 | for i, (input, target) in enumerate(val_loader): 309 | target = target.cuda() 310 | 311 | output = model(input) 312 | loss = criterion(output, target) 313 | 314 | # measure accuracy and record loss 315 | err1, err5 = accuracy(output.data, target, topk=(1, 5)) 316 | 317 | losses.update(loss.item(), input.size(0)) 318 | 319 | top1.update(err1.item(), input.size(0)) 320 | top5.update(err5.item(), input.size(0)) 321 | 322 | # measure elapsed time 323 | batch_time.update(time.time() - end) 324 | end = time.time() 325 | 326 | if i % args.print_freq == 0 and args.verbose == True: 327 | print('Test (on val set): [{0}/{1}][{2}/{3}]\t' 328 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 329 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 330 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t' 331 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format( 332 | epoch, args.epochs, i, len(val_loader), batch_time=batch_time, loss=losses, 333 | top1=top1, top5=top5)) 334 | 335 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Test Loss {loss.avg:.3f}'.format( 336 | epoch, args.epochs, top1=top1, top5=top5, loss=losses)) 337 | return top1.avg, top5.avg, losses.avg 338 | 339 | 340 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 341 | directory = "runs/%s/" % (args.expname) 342 | if not os.path.exists(directory): 343 | os.makedirs(directory) 344 | filename = directory + filename 345 | torch.save(state, filename) 346 | if is_best: 347 | shutil.copyfile(filename, 'runs/%s/' % (args.expname) + 'model_best.pth.tar') 348 | 349 | 350 | class AverageMeter(object): 351 | """Computes and stores the average and current value""" 352 | 353 | def __init__(self): 354 | self.reset() 355 | 356 | def reset(self): 357 | self.val = 0 358 | self.avg = 0 359 | self.sum = 0 360 | self.count = 0 361 | 362 | def update(self, val, n=1): 363 | self.val = val 364 | self.sum += val * n 365 | self.count += n 366 | self.avg = self.sum / self.count 367 | 368 | 369 | def adjust_learning_rate(optimizer, epoch): 370 | """Sets the learning rate: decayed by 10 at 50% and 75% of total epochs on CIFAR; decayed by 10 every 30 epochs (every 75 for 300-epoch runs) on ImageNet""" 371 | if args.dataset.startswith('cifar'): 372 | lr = args.lr * (0.1 ** (epoch // (args.epochs * 0.5))) * (0.1 ** (epoch // (args.epochs * 0.75))) 373 | elif args.dataset == ('imagenet'): 374 | if args.epochs == 300: 375 | lr = args.lr * (0.1 ** (epoch // 75)) 376 | else: 377 | lr = args.lr * (0.1 ** (epoch // 30)) 378 | 379 | for param_group in optimizer.param_groups: 380 | param_group['lr'] = lr 381 | 382 | 383 | def get_learning_rate(optimizer): 384 | lr = [] 385 | for param_group in optimizer.param_groups: 386 | lr += [param_group['lr']] 387 | return lr 388 | 389 | 390 | def accuracy(output, target, topk=(1,)): 391 | """Computes the top-k error (100 - precision@k) for the specified values of k""" 392 | maxk = max(topk) 393 | batch_size = target.size(0) 394 | 395 |
_, pred = output.topk(maxk, 1, True, True) 396 | pred = pred.t() 397 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 398 | 399 | res = [] 400 | for k in topk: 401 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 402 | wrong_k = batch_size - correct_k 403 | res.append(wrong_k.mul_(100.0 / batch_size)) 404 | 405 | return res 406 | 407 | 408 | if __name__ == '__main__': 409 | main() 410 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py 2 | 3 | import torch 4 | import random 5 | 6 | __all__ = ["Compose", "Lighting", "ColorJitter"] 7 | 8 | 9 | class Compose(object): 10 | """Composes several transforms together. 11 | 12 | Args: 13 | transforms (list of ``Transform`` objects): list of transforms to compose. 14 | 15 | Example: 16 | >>> transforms.Compose([ 17 | >>> transforms.CenterCrop(10), 18 | >>> transforms.ToTensor(), 19 | >>> ]) 20 | """ 21 | 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __call__(self, img): 26 | for t in self.transforms: 27 | img = t(img) 28 | return img 29 | 30 | def __repr__(self): 31 | format_string = self.__class__.__name__ + '(' 32 | for t in self.transforms: 33 | format_string += '\n' 34 | format_string += ' {0}'.format(t) 35 | format_string += '\n)' 36 | return format_string 37 | 38 | 39 | class Lighting(object): 40 | """Lighting noise (AlexNet-style PCA-based noise)""" 41 | 42 | def __init__(self, alphastd, eigval, eigvec): 43 | self.alphastd = alphastd 44 | self.eigval = torch.Tensor(eigval) 45 | self.eigvec = torch.Tensor(eigvec) 46 | 47 | def __call__(self, img): 48 | if self.alphastd == 0: 49 | return img 50 | 51 | alpha = img.new().resize_(3).normal_(0, self.alphastd) 52 | rgb = self.eigvec.type_as(img).clone() \ 53 | .mul(alpha.view(1, 3).expand(3, 3)) \ 54 | .mul(self.eigval.view(1, 3).expand(3, 3)) \ 55 | .sum(1).squeeze() 56 | 57 | return img.add(rgb.view(3, 1, 1).expand_as(img)) 58 | 59 | 60 | class Grayscale(object): 61 | 62 | def __call__(self, img): 63 | gs = img.clone() 64 | gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)  # ITU-R BT.601 luma weights 65 | gs[1].copy_(gs[0]) 66 | gs[2].copy_(gs[0]) 67 | return gs 68 | 69 | 70 | class Saturation(object): 71 | 72 | def __init__(self, var): 73 | self.var = var 74 | 75 | def __call__(self, img): 76 | gs = Grayscale()(img) 77 | alpha = random.uniform(-self.var, self.var) 78 | return img.lerp(gs, alpha) 79 | 80 | 81 | class Brightness(object): 82 | 83 | def __init__(self, var): 84 | self.var = var 85 | 86 | def __call__(self, img): 87 | gs = img.new().resize_as_(img).zero_() 88 | alpha = random.uniform(-self.var, self.var) 89 | return img.lerp(gs, alpha) 90 | 91 | 92 | class Contrast(object): 93 | 94 | def __init__(self, var): 95 | self.var = var 96 | 97 | def __call__(self, img): 98 | gs = Grayscale()(img) 99 | gs.fill_(gs.mean()) 100 | alpha = random.uniform(-self.var, self.var) 101 | return img.lerp(gs, alpha) 102 | 103 | 104 | class ColorJitter(object): 105 | 106 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4): 107 | self.brightness = brightness 108 | self.contrast = contrast 109 | self.saturation = saturation 110 | 111 | def __call__(self, img): 112 | self.transforms = [] 113 | if self.brightness != 0: 114 | self.transforms.append(Brightness(self.brightness)) 115 | if self.contrast != 0: 116 |
self.transforms.append(Contrast(self.contrast)) 117 | if self.saturation != 0: 118 | self.transforms.append(Saturation(self.saturation)) 119 | 120 | random.shuffle(self.transforms) 121 | transform = Compose(self.transforms) 122 | # print(transform) 123 | return transform(img) 124 | --------------------------------------------------------------------------------