├── .gitignore
├── LICENSE
├── NOTICE
├── README.md
├── img1.PNG
├── pyramidnet.py
├── resnet.py
├── test.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .idea/
3 | runs/
4 | train_nsml.py
5 | requirements.txt
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2019-present NAVER Corp.
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | CutMix
2 | Copyright (c) 2019-present NAVER Corp.
3 |
4 | This project contains subcomponents with separate copyright notices and license terms.
5 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6 |
7 | =======================================================================
8 | pytorch/vision from https://github.com/pytorch/vision
9 | =======================================================================
10 |
11 | BSD 3-Clause License
12 |
13 | Copyright (c) Soumith Chintala 2016,
14 | All rights reserved.
15 |
16 | Redistribution and use in source and binary forms, with or without
17 | modification, are permitted provided that the following conditions are met:
18 |
19 | * Redistributions of source code must retain the above copyright notice, this
20 | list of conditions and the following disclaimer.
21 |
22 | * Redistributions in binary form must reproduce the above copyright notice,
23 | this list of conditions and the following disclaimer in the documentation
24 | and/or other materials provided with the distribution.
25 |
26 | * Neither the name of the copyright holder nor the names of its
27 | contributors may be used to endorse or promote products derived from
28 | this software without specific prior written permission.
29 |
30 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
40 |
41 | =======================================================================
42 | dyhan0920/PyramidNet-PyTorch from https://github.com/dyhan0920/PyramidNet-PyTorch
43 | =======================================================================
44 |
45 | MIT License
46 |
47 | Copyright (c) 2019 Dongyoon Han
48 |
49 | Permission is hereby granted, free of charge, to any person obtaining a copy
50 | of this software and associated documentation files (the "Software"), to deal
51 | in the Software without restriction, including without limitation the rights
52 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
53 | copies of the Software, and to permit persons to whom the Software is
54 | furnished to do so, subject to the following conditions:
55 |
56 | The above copyright notice and this permission notice shall be included in all
57 | copies or substantial portions of the Software.
58 |
59 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
60 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
61 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
62 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
63 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
64 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
65 | SOFTWARE.
66 |
67 | =======================================================================
68 | eladhoffer/convNet.pytorch from https://github.com/eladhoffer/convNet.pytorch
69 | =======================================================================
70 |
71 | MIT License
72 |
73 | Copyright (c) 2017 Elad Hoffer
74 |
75 | Permission is hereby granted, free of charge, to any person obtaining a copy
76 | of this software and associated documentation files (the "Software"), to deal
77 | in the Software without restriction, including without limitation the rights
78 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
79 | copies of the Software, and to permit persons to whom the Software is
80 | furnished to do so, subject to the following conditions:
81 |
82 | The above copyright notice and this permission notice shall be included in all
83 | copies or substantial portions of the Software.
84 |
85 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
86 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
87 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
88 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
89 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
90 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
91 | SOFTWARE.
92 |
93 | =====
94 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Accepted at ICCV 2019 (oral talk) !!
2 |
3 | ## CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
4 |
5 | Official PyTorch implementation of the CutMix regularizer | [Paper](https://arxiv.org/abs/1905.04899) | [Pretrained Models](#experiments)
6 |
7 | **[Sangdoo Yun](mailto:sangdoo.yun@navercorp.com), Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.**
8 |
9 | Clova AI Research, NAVER Corp.
10 |
11 | Our implementation is based on these repositories:
12 | - [PyTorch ImageNet Example](https://github.com/pytorch/examples/tree/master/imagenet)
13 | - [PyramidNet-PyTorch](https://github.com/dyhan0920/PyramidNet-PyTorch)
14 |
15 |
16 | ### Abstract
17 | Regional dropout strategies have been proposed to enhance the performance of convolutional neural network classifiers.
18 | They have proved to be effective for guiding the model to attend to less discriminative parts of objects
19 | (e.g. leg as opposed to head of a person), thereby letting the network generalize better and have better object localization capabilities.
20 | On the other hand, current methods for regional dropout remove informative pixels from training images by overlaying a patch of either black pixels or random noise.
21 | Such removal is not desirable because it leads to information loss and inefficiency during training.
22 | We therefore propose the **CutMix** augmentation strategy: patches are cut and pasted among training images where the ground truth labels are also mixed proportionally to the area of the patches.
23 | By making efficient use of training pixels and retaining the regularization effect of regional dropout, CutMix consistently outperforms the state-of-the-art augmentation strategies on CIFAR and ImageNet classification tasks, as well as on the ImageNet weakly-supervised localization task.
24 | Moreover, unlike previous augmentation methods, our CutMix-trained ImageNet classifier, when used as a pretrained model, results in consistent performance gains in Pascal detection and MS-COCO image captioning benchmarks.
25 | We also show that CutMix improves the model robustness against input corruptions and its out-of-distribution detection performances.
26 |
27 |
28 | ### Overview of the results of Mixup, Cutout, and CutMix.
29 |
30 | ![Overview of the results of Mixup, Cutout, and CutMix](img1.PNG)
31 |
32 | ## Updates
33 | **23 May, 2019**: Initial upload
34 |
35 | ## Getting Started
36 | ### Requirements
37 | - Python3
38 | - PyTorch (> 1.0)
39 | - torchvision (> 0.2)
40 | - NumPy
41 |
42 | ### Train Examples
43 | - CIFAR-100: We used 2 GPUs to train on CIFAR-100.
44 | ```
45 | python train.py \
46 | --net_type pyramidnet \
47 | --dataset cifar100 \
48 | --depth 200 \
49 | --alpha 240 \
50 | --batch_size 64 \
51 | --lr 0.25 \
52 | --expname PyraNet200 \
53 | --epochs 300 \
54 | --beta 1.0 \
55 | --cutmix_prob 0.5 \
56 | --no-verbose
57 | ```
58 | - ImageNet: We used 4 GPUs to train on ImageNet.
59 | ```
60 | python train.py \
61 | --net_type resnet \
62 | --dataset imagenet \
63 | --batch_size 256 \
64 | --lr 0.1 \
65 | --depth 50 \
66 | --epochs 300 \
67 | --expname ResNet50 \
68 | -j 40 \
69 | --beta 1.0 \
70 | --cutmix_prob 1.0 \
71 | --no-verbose
72 | ```
73 |
74 | ### Test Examples Using Pretrained Models
75 | - Download [CutMix-pretrained PyramidNet200 (top-1 error: 14.23)](https://www.dropbox.com/sh/o68qbvayptt2rz5/AACy3o779BxoRqw6_GQf_QFQa?dl=0)
76 | ```
77 | python test.py \
78 | --net_type pyramidnet \
79 | --dataset cifar100 \
80 | --batch_size 64 \
81 | --depth 200 \
82 | --alpha 240 \
83 | --pretrained /set/your/model/path/model_best.pth.tar
84 | ```
85 | - Download [CutMix-pretrained ResNet50 (top-1 error: 21.40)](https://www.dropbox.com/sh/w8dvfgdc3eirivf/AABnGcTO9wao9xVGWwqsXRala?dl=0)
86 | ```
87 | python test.py \
88 | --net_type resnet \
89 | --dataset imagenet \
90 | --batch_size 64 \
91 | --depth 50 \
92 | --pretrained /set/your/model/path/model_best.pth.tar
93 | ```
94 |
95 | ## Experimental Results and Pretrained Models
96 |
97 | - PyramidNet-200 pretrained on CIFAR-100 dataset:
98 |
99 | Method | Top-1 Error | Model file
100 | -- | -- | --
101 | PyramidNet-200 [[CVPR'17](https://arxiv.org/abs/1610.02915)] (baseline) | 16.45 | [model](https://www.dropbox.com/sh/6rfew3lr761jq6c/AADrdQOXNx5tWmgOSnAw9NEVa?dl=0)
102 | PyramidNet-200 + **CutMix** | **14.23** | [model](https://www.dropbox.com/sh/o68qbvayptt2rz5/AACy3o779BxoRqw6_GQf_QFQa?dl=0)
103 | PyramidNet-200 + Shakedrop [[arXiv'18](https://arxiv.org/abs/1802.02375)] + **CutMix** | **13.81** | -
104 | PyramidNet-200 + Mixup [[ICLR'18](https://arxiv.org/abs/1710.09412)] | 15.63 | [model](https://www.dropbox.com/sh/g55jnsv62v0n59s/AAC9LPg-LjlnBn4ttKs6vr7Ka?dl=0)
105 | PyramidNet-200 + Manifold Mixup [[ICML'19](https://arxiv.org/abs/1806.05236)] | 16.14 | [model](https://www.dropbox.com/sh/nngw7hhk1e8msbr/AABkdCsP0ABnQJDBX7LQVj4la?dl=0)
106 | PyramidNet-200 + Cutout [[arXiv'17](https://arxiv.org/abs/1708.04552)] | 16.53 | [model](https://www.dropbox.com/sh/ajjz4q8c8t6qva9/AAAeBGb2Q4TnJMW0JAzeVSpfa?dl=0)
107 | PyramidNet-200 + DropBlock [[NeurIPS'18](https://arxiv.org/abs/1810.12890)] | 15.73 | [model](https://www.dropbox.com/sh/vefjo960gyrsx2i/AACYA5wOJ_yroNjIjdsN1Dz2a?dl=0)
108 | PyramidNet-200 + Cutout + Labelsmoothing | 15.61 | [model](https://www.dropbox.com/sh/1mur0kjcfxdn7jn/AADmghqrj0dXAG0qY1v3Csb6a?dl=0)
109 | PyramidNet-200 + DropBlock + Labelsmoothing | 15.16 | [model](https://www.dropbox.com/sh/n1dn6ggyxjcoogc/AADpSSNzvaraSCqWtHBE0qMca?dl=0)
110 | PyramidNet-200 + Cutout + Mixup | 15.46 | [model](https://www.dropbox.com/sh/5run1sx8oy0v9oi/AACiR_wEBQVp2HMZFx6lGl3ka?dl=0)
111 |
112 |
113 | - ResNet models pretrained on ImageNet dataset:
114 |
115 | Method | Top-1 Error | Model file
116 | -- | -- | --
117 | ResNet-50 [[CVPR'16](https://arxiv.org/abs/1512.03385)] (baseline) | 23.68 | [model](https://www.dropbox.com/sh/phwbbrtadrclpnx/AAA9QUW9G_xvBdI-mDiIzP_Ha?dl=0)
118 | ResNet-50 + **CutMix** | **21.40** | [model](https://www.dropbox.com/sh/w8dvfgdc3eirivf/AABnGcTO9wao9xVGWwqsXRala?dl=0)
119 | ResNet-50 + **Feature CutMix** | **21.80** | [model](https://www.dropbox.com/sh/zj1wptsg0hwqf0k/AABRNzvjFmIS7_vOEQkqb6T4a?dl=0)
120 | ResNet-50 + Mixup [[ICLR'18](https://arxiv.org/abs/1710.09412)] | 22.58 | [model](https://www.dropbox.com/sh/g64c8bda61n12if/AACyaTZnku_Sgibc9UvOSblNa?dl=0)
121 | ResNet-50 + Manifold Mixup [[ICML'19](https://arxiv.org/abs/1806.05236)] | 22.50 | [model](https://www.dropbox.com/sh/bjardjje11pti0g/AABFGW0gNrNE8o8TqUf4-SYSa?dl=0)
122 | ResNet-50 + Cutout [[arXiv'17](https://arxiv.org/abs/1708.04552)] | 22.93 | [model](https://www.dropbox.com/sh/ln8zk2z7zt2h1en/AAA7z8xTBlzz7Ofbd5L7oTnTa?dl=0)
123 | ResNet-50 + AutoAugment [[CVPR'19](https://arxiv.org/abs/1805.09501)] | 22.40* | -
124 | ResNet-50 + DropBlock [[NeurIPS'18](https://arxiv.org/abs/1810.12890)] | 21.87* | -
125 | ResNet-101 + **CutMix** | **20.17** | [model](https://www.dropbox.com/sh/1z4xnp9nwdmpzb5/AACQX4KU8XkTN0JSTfjkCktNa?dl=0)
126 | ResNet-152 + **CutMix** | **19.20** | [model](https://www.dropbox.com/s/6vq1mzy27z8qxko/resnet152_cutmix_acc_80_80.pth?dl=0)
127 | ResNeXt-101 (32x4d) + **CutMix** | **19.47** | [model](https://www.dropbox.com/s/maysvgopsi17qi0/resnext_cutmix.pth.tar?dl=0)
128 |
129 | \* denotes results reported in the original papers
130 |
131 | ## Transfer Learning Results
132 |
133 | Backbone | ImageNet Cls (%) | ImageNet Loc (%) | CUB200 Loc (%) | Detection (SSD) (mAP) | Detection (Faster-RCNN) (mAP) | Image Captioning (BLEU-4)
134 | -- | -- | -- | -- | -- | -- | --
135 | ResNet50 | 23.68 | 46.3 | 49.41 | 76.7 | 75.6 | 22.9
136 | ResNet50+Mixup | 22.58 | 45.84 | 49.3 | 76.6 | 73.9 | 23.2
137 | ResNet50+Cutout | 22.93 | 46.69 | 52.78 | 76.8 | 75 | 24.0
138 | ResNet50+**CutMix** | **21.60** | **46.25** | **54.81** | **77.6** | **76.7** | **24.9**
139 |
140 |
141 | ## Third-party Implementations
142 | - [Pytorch-CutMix](https://github.com/hysts/pytorch_cutmix) by @hysts
143 | - [TensorFlow-CutMix](https://github.com/jis478/Tensorflow/tree/master/TF2.0/Cutmix) by @jis478
144 |
145 | ## Citation
146 | ```
147 | @inproceedings{yun2019cutmix,
148 | title={CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features},
149 | author={Yun, Sangdoo and Han, Dongyoon and Oh, Seong Joon and Chun, Sanghyuk and Choe, Junsuk and Yoo, Youngjoon},
150 | booktitle = {International Conference on Computer Vision (ICCV)},
151 | year={2019},
152 | pubstate={published},
153 | tppubtype={inproceedings}
154 | }
155 | ```
156 |
157 | ## License
158 | ```
159 | Copyright (c) 2019-present NAVER Corp.
160 |
161 | Permission is hereby granted, free of charge, to any person obtaining a copy
162 | of this software and associated documentation files (the "Software"), to deal
163 | in the Software without restriction, including without limitation the rights
164 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
165 | copies of the Software, and to permit persons to whom the Software is
166 | furnished to do so, subject to the following conditions:
167 |
168 | The above copyright notice and this permission notice shall be included in
169 | all copies or substantial portions of the Software.
170 |
171 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
172 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
173 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
174 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
175 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
176 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
177 | THE SOFTWARE.
178 | ```
179 |
--------------------------------------------------------------------------------
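The CutMix operation described in the README's abstract is implemented by `rand_bbox()` and the mixing branch of `train()` in `train.py`. As a quick reference, here is a minimal, self-contained sketch of those same steps on a toy batch (the batch shape, class count, and `beta = 1.0` are illustrative, mirroring the CIFAR-100 command above):

```
import numpy as np
import torch

def rand_bbox(size, lam):
    # Sample a box whose nominal area is (1 - lam) of the image, as in train.py.
    W, H = size[2], size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w, cut_h = int(W * cut_rat), int(H * cut_rat)
    cx, cy = np.random.randint(W), np.random.randint(H)
    return (np.clip(cx - cut_w // 2, 0, W), np.clip(cy - cut_h // 2, 0, H),
            np.clip(cx + cut_w // 2, 0, W), np.clip(cy + cut_h // 2, 0, H))

input = torch.randn(8, 3, 32, 32)            # toy CIFAR-sized batch
target = torch.randint(0, 100, (8,))

lam = np.random.beta(1.0, 1.0)               # --beta 1.0
rand_index = torch.randperm(input.size(0))
target_a, target_b = target, target[rand_index]
bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
# Adjust lambda to the exact pixel ratio of the pasted box (clipping can shrink it).
lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size(-1) * input.size(-2)))
# Training then uses: criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
```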
/img1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/CutMix-PyTorch/2d8eb68faff7fe4962776ad51d175c3b01a25734/img1.PNG
--------------------------------------------------------------------------------
/pyramidnet.py:
--------------------------------------------------------------------------------
1 | # Original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/PyramidNet.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import math
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | "3x3 convolution with padding"
9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
10 | padding=1, bias=False)
11 |
12 |
13 | class BasicBlock(nn.Module):
14 | outchannel_ratio = 1
15 |
16 | def __init__(self, inplanes, planes, stride=1, downsample=None):
17 | super(BasicBlock, self).__init__()
18 | self.bn1 = nn.BatchNorm2d(inplanes)
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.conv2 = conv3x3(planes, planes)
22 | self.bn3 = nn.BatchNorm2d(planes)
23 | self.relu = nn.ReLU(inplace=True)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | def forward(self, x):
28 |
29 | out = self.bn1(x)
30 | out = self.conv1(out)
31 | out = self.bn2(out)
32 | out = self.relu(out)
33 | out = self.conv2(out)
34 | out = self.bn3(out)
35 | if self.downsample is not None:
36 | shortcut = self.downsample(x)
37 | featuremap_size = shortcut.size()[2:4]
38 | else:
39 | shortcut = x
40 | featuremap_size = out.size()[2:4]
41 |
42 | batch_size = out.size()[0]
43 | residual_channel = out.size()[1]
44 | shortcut_channel = shortcut.size()[1]
45 |
46 | if residual_channel != shortcut_channel:
47 | padding = torch.zeros(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1], device=out.device, dtype=out.dtype)
48 | out += torch.cat((shortcut, padding), 1)
49 | else:
50 | out += shortcut
51 |
52 | return out
53 |
54 |
55 | class Bottleneck(nn.Module):
56 | outchannel_ratio = 4
57 |
58 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16):
59 | super(Bottleneck, self).__init__()
60 | self.bn1 = nn.BatchNorm2d(inplanes)
61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
62 | self.bn2 = nn.BatchNorm2d(planes)
63 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
64 | self.bn3 = nn.BatchNorm2d(planes)
65 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.outchannel_ratio, kernel_size=1, bias=False)
66 | self.bn4 = nn.BatchNorm2d(planes * Bottleneck.outchannel_ratio)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | self.downsample = downsample
70 | self.stride = stride
71 |
72 | def forward(self, x):
73 |
74 | out = self.bn1(x)
75 | out = self.conv1(out)
76 |
77 | out = self.bn2(out)
78 | out = self.relu(out)
79 | out = self.conv2(out)
80 |
81 | out = self.bn3(out)
82 | out = self.relu(out)
83 | out = self.conv3(out)
84 |
85 | out = self.bn4(out)
86 | if self.downsample is not None:
87 | shortcut = self.downsample(x)
88 | featuremap_size = shortcut.size()[2:4]
89 | else:
90 | shortcut = x
91 | featuremap_size = out.size()[2:4]
92 |
93 | batch_size = out.size()[0]
94 | residual_channel = out.size()[1]
95 | shortcut_channel = shortcut.size()[1]
96 |
97 | if residual_channel != shortcut_channel:
98 | padding = torch.zeros(batch_size, residual_channel - shortcut_channel, featuremap_size[0], featuremap_size[1], device=out.device, dtype=out.dtype)
99 | out += torch.cat((shortcut, padding), 1)
100 | else:
101 | out += shortcut
102 |
103 | return out
104 |
105 |
106 | class PyramidNet(nn.Module):
107 |
108 | def __init__(self, dataset, depth, alpha, num_classes, bottleneck=False):
109 | super(PyramidNet, self).__init__()
110 | self.dataset = dataset
111 | if self.dataset.startswith('cifar'):
112 | self.inplanes = 16
113 | if bottleneck:
114 | n = int((depth - 2) / 9)
115 | block = Bottleneck
116 | else:
117 | n = int((depth - 2) / 6)
118 | block = BasicBlock
119 |
120 | self.addrate = alpha / (3*n*1.0)
121 |
122 | self.input_featuremap_dim = self.inplanes
123 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=3, stride=1, padding=1, bias=False)
124 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
125 |
126 | self.featuremap_dim = self.input_featuremap_dim
127 | self.layer1 = self.pyramidal_make_layer(block, n)
128 | self.layer2 = self.pyramidal_make_layer(block, n, stride=2)
129 | self.layer3 = self.pyramidal_make_layer(block, n, stride=2)
130 |
131 | self.final_featuremap_dim = self.input_featuremap_dim
132 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
133 | self.relu_final = nn.ReLU(inplace=True)
134 | self.avgpool = nn.AvgPool2d(8)
135 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
136 |
137 | elif dataset == 'imagenet':
138 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
139 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
140 |
141 | if layers.get(depth) is None:
142 | if bottleneck:
143 | blocks[depth] = Bottleneck
144 | temp_cfg = int((depth-2)/12)
145 | else:
146 | blocks[depth] = BasicBlock
147 | temp_cfg = int((depth-2)/8)
148 |
149 | layers[depth] = [temp_cfg, temp_cfg, temp_cfg, temp_cfg]
150 | print('=> the layer configuration for each stage is set to', layers[depth])
151 |
152 | self.inplanes = 64
153 | self.addrate = alpha / (sum(layers[depth])*1.0)
154 |
155 | self.input_featuremap_dim = self.inplanes
156 | self.conv1 = nn.Conv2d(3, self.input_featuremap_dim, kernel_size=7, stride=2, padding=3, bias=False)
157 | self.bn1 = nn.BatchNorm2d(self.input_featuremap_dim)
158 | self.relu = nn.ReLU(inplace=True)
159 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
160 |
161 | self.featuremap_dim = self.input_featuremap_dim
162 | self.layer1 = self.pyramidal_make_layer(blocks[depth], layers[depth][0])
163 | self.layer2 = self.pyramidal_make_layer(blocks[depth], layers[depth][1], stride=2)
164 | self.layer3 = self.pyramidal_make_layer(blocks[depth], layers[depth][2], stride=2)
165 | self.layer4 = self.pyramidal_make_layer(blocks[depth], layers[depth][3], stride=2)
166 |
167 | self.final_featuremap_dim = self.input_featuremap_dim
168 | self.bn_final = nn.BatchNorm2d(self.final_featuremap_dim)
169 | self.relu_final = nn.ReLU(inplace=True)
170 | self.avgpool = nn.AvgPool2d(7)
171 | self.fc = nn.Linear(self.final_featuremap_dim, num_classes)
172 |
173 | for m in self.modules():
174 | if isinstance(m, nn.Conv2d):
175 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
176 | m.weight.data.normal_(0, math.sqrt(2. / n))
177 | elif isinstance(m, nn.BatchNorm2d):
178 | m.weight.data.fill_(1)
179 | m.bias.data.zero_()
180 |
181 | def pyramidal_make_layer(self, block, block_depth, stride=1):
182 | downsample = None
183 | if stride != 1: # or self.inplanes != int(round(featuremap_dim_1st)) * block.outchannel_ratio:
184 | downsample = nn.AvgPool2d((2,2), stride = (2, 2), ceil_mode=True)
185 |
186 | layers = []
187 | self.featuremap_dim = self.featuremap_dim + self.addrate
188 | layers.append(block(self.input_featuremap_dim, int(round(self.featuremap_dim)), stride, downsample))
189 | for i in range(1, block_depth):
190 | temp_featuremap_dim = self.featuremap_dim + self.addrate
191 | layers.append(block(int(round(self.featuremap_dim)) * block.outchannel_ratio, int(round(temp_featuremap_dim)), 1))
192 | self.featuremap_dim = temp_featuremap_dim
193 | self.input_featuremap_dim = int(round(self.featuremap_dim)) * block.outchannel_ratio
194 |
195 | return nn.Sequential(*layers)
196 |
197 | def forward(self, x):
198 | if self.dataset == 'cifar10' or self.dataset == 'cifar100':
199 | x = self.conv1(x)
200 | x = self.bn1(x)
201 |
202 | x = self.layer1(x)
203 | x = self.layer2(x)
204 | x = self.layer3(x)
205 |
206 | x = self.bn_final(x)
207 | x = self.relu_final(x)
208 | x = self.avgpool(x)
209 | x = x.view(x.size(0), -1)
210 | x = self.fc(x)
211 |
212 | elif self.dataset == 'imagenet':
213 | x = self.conv1(x)
214 | x = self.bn1(x)
215 | x = self.relu(x)
216 | x = self.maxpool(x)
217 |
218 | x = self.layer1(x)
219 | x = self.layer2(x)
220 | x = self.layer3(x)
221 | x = self.layer4(x)
222 |
223 | x = self.bn_final(x)
224 | x = self.relu_final(x)
225 | x = self.avgpool(x)
226 | x = x.view(x.size(0), -1)
227 | x = self.fc(x)
228 |
229 | return x
230 |
--------------------------------------------------------------------------------
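A minimal smoke test for the module above (a sketch: the depth/alpha values are small illustrative choices so it runs quickly; the README's full CIFAR-100 recipe uses depth 200 and alpha 240 with the bottleneck block):

```
import torch
from pyramidnet import PyramidNet

# Basic-block CIFAR variant: depth 110 gives n = (110 - 2) / 6 = 18 blocks per stage.
model = PyramidNet('cifar100', depth=110, alpha=64, num_classes=100, bottleneck=False)
x = torch.randn(2, 3, 32, 32)    # CIFAR input resolution
print(model(x).shape)            # expected: torch.Size([2, 100])
```

Note that the zero-padding shortcut inside the blocks allocates on the input's device, so this also runs without a GPU.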
/resnet.py:
--------------------------------------------------------------------------------
1 | # Original code: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 |
3 | import torch.nn as nn
4 | import math
5 |
6 | def conv3x3(in_planes, out_planes, stride=1):
7 | "3x3 convolution with padding"
8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 | padding=1, bias=False)
10 |
11 |
12 | class BasicBlock(nn.Module):
13 | expansion = 1
14 |
15 | def __init__(self, inplanes, planes, stride=1, downsample=None):
16 | super(BasicBlock, self).__init__()
17 | self.conv1 = conv3x3(inplanes, planes, stride)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.conv2 = conv3x3(planes, planes)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 |
23 | self.downsample = downsample
24 | self.stride = stride
25 |
26 | def forward(self, x):
27 | residual = x
28 |
29 | out = self.conv1(x)
30 | out = self.bn1(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv2(out)
34 | out = self.bn2(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class Bottleneck(nn.Module):
46 | expansion = 4
47 |
48 | def __init__(self, inplanes, planes, stride=1, downsample=None):
49 | super(Bottleneck, self).__init__()
50 |
51 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
52 | self.bn1 = nn.BatchNorm2d(planes)
53 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
54 | self.bn2 = nn.BatchNorm2d(planes)
55 | self.conv3 = nn.Conv2d(planes, planes * Bottleneck.expansion, kernel_size=1, bias=False)
56 | self.bn3 = nn.BatchNorm2d(planes * Bottleneck.expansion)
57 | self.relu = nn.ReLU(inplace=True)
58 |
59 | self.downsample = downsample
60 | self.stride = stride
61 |
62 | def forward(self, x):
63 | residual = x
64 |
65 | out = self.conv1(x)
66 | out = self.bn1(out)
67 | out = self.relu(out)
68 |
69 | out = self.conv2(out)
70 | out = self.bn2(out)
71 | out = self.relu(out)
72 |
73 | out = self.conv3(out)
74 | out = self.bn3(out)
75 | if self.downsample is not None:
76 | residual = self.downsample(x)
77 |
78 | out += residual
79 | out = self.relu(out)
80 |
81 | return out
82 |
83 | class ResNet(nn.Module):
84 | def __init__(self, dataset, depth, num_classes, bottleneck=False):
85 | super(ResNet, self).__init__()
86 | self.dataset = dataset
87 | if self.dataset.startswith('cifar'):
88 | self.inplanes = 16
89 | print(bottleneck)
90 | if bottleneck:
91 | n = int((depth - 2) / 9)
92 | block = Bottleneck
93 | else:
94 | n = int((depth - 2) / 6)
95 | block = BasicBlock
96 |
97 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
98 | self.bn1 = nn.BatchNorm2d(self.inplanes)
99 | self.relu = nn.ReLU(inplace=True)
100 | self.layer1 = self._make_layer(block, 16, n)
101 | self.layer2 = self._make_layer(block, 32, n, stride=2)
102 | self.layer3 = self._make_layer(block, 64, n, stride=2)
103 | self.avgpool = nn.AvgPool2d(8)
104 | self.fc = nn.Linear(64 * block.expansion, num_classes)
105 |
106 | elif dataset == 'imagenet':
107 | blocks = {18: BasicBlock, 34: BasicBlock, 50: Bottleneck, 101: Bottleneck, 152: Bottleneck, 200: Bottleneck}
108 | layers = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3]}
109 | assert depth in layers, 'invalid depth for ResNet (depth should be one of 18, 34, 50, 101, 152, and 200)'
110 |
111 | self.inplanes = 64
112 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
113 | self.bn1 = nn.BatchNorm2d(64)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
116 | self.layer1 = self._make_layer(blocks[depth], 64, layers[depth][0])
117 | self.layer2 = self._make_layer(blocks[depth], 128, layers[depth][1], stride=2)
118 | self.layer3 = self._make_layer(blocks[depth], 256, layers[depth][2], stride=2)
119 | self.layer4 = self._make_layer(blocks[depth], 512, layers[depth][3], stride=2)
120 | self.avgpool = nn.AvgPool2d(7)
121 | self.fc = nn.Linear(512 * blocks[depth].expansion, num_classes)
122 |
123 | for m in self.modules():
124 | if isinstance(m, nn.Conv2d):
125 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
126 | m.weight.data.normal_(0, math.sqrt(2. / n))
127 | elif isinstance(m, nn.BatchNorm2d):
128 | m.weight.data.fill_(1)
129 | m.bias.data.zero_()
130 |
131 | def _make_layer(self, block, planes, blocks, stride=1):
132 | downsample = None
133 | if stride != 1 or self.inplanes != planes * block.expansion:
134 | downsample = nn.Sequential(
135 | nn.Conv2d(self.inplanes, planes * block.expansion,
136 | kernel_size=1, stride=stride, bias=False),
137 | nn.BatchNorm2d(planes * block.expansion),
138 | )
139 |
140 | layers = []
141 | layers.append(block(self.inplanes, planes, stride, downsample))
142 | self.inplanes = planes * block.expansion
143 | for i in range(1, blocks):
144 | layers.append(block(self.inplanes, planes))
145 |
146 | return nn.Sequential(*layers)
147 |
148 | def forward(self, x):
149 | if self.dataset == 'cifar10' or self.dataset == 'cifar100':
150 | x = self.conv1(x)
151 | x = self.bn1(x)
152 | x = self.relu(x)
153 |
154 | x = self.layer1(x)
155 | x = self.layer2(x)
156 | x = self.layer3(x)
157 |
158 | x = self.avgpool(x)
159 | x = x.view(x.size(0), -1)
160 | x = self.fc(x)
161 |
162 | elif self.dataset == 'imagenet':
163 | x = self.conv1(x)
164 | x = self.bn1(x)
165 | x = self.relu(x)
166 | x = self.maxpool(x)
167 |
168 | x = self.layer1(x)
169 | x = self.layer2(x)
170 | x = self.layer3(x)
171 | x = self.layer4(x)
172 |
173 | x = self.avgpool(x)
174 | x = x.view(x.size(0), -1)
175 | x = self.fc(x)
176 |
177 | return x
178 |
--------------------------------------------------------------------------------
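As with `pyramidnet.py`, a short instantiation sketch (the depth must be a key of `layers` for the ImageNet variant; 224x224 matches the crop size used by `train.py` and `test.py`):

```
import torch
from resnet import ResNet

model = ResNet('imagenet', depth=50, num_classes=1000, bottleneck=True)
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)            # expected: torch.Size([1, 1000])
```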
/test.py:
--------------------------------------------------------------------------------
1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py
2 |
3 | import argparse
4 | import os
5 | import shutil
6 | import time
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | import torchvision.transforms as transforms
16 | import torchvision.datasets as datasets
17 | import torchvision.models as models
18 | import resnet as RN
19 | import pyramidnet as PYRM
20 |
21 | import warnings
22 |
23 | warnings.filterwarnings("ignore")
24 |
25 | model_names = sorted(name for name in models.__dict__
26 | if name.islower() and not name.startswith("__")
27 | and callable(models.__dict__[name]))
28 |
29 | parser = argparse.ArgumentParser(description='CutMix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Test')
30 | parser.add_argument('--net_type', default='pyramidnet', type=str,
31 | help='network type: resnet or pyramidnet')
32 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
33 | help='number of data loading workers (default: 4)')
34 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
35 | help='number of total epochs to run')
36 | parser.add_argument('-b', '--batch_size', default=128, type=int,
37 | metavar='N', help='mini-batch size (default: 128)')
38 | parser.add_argument('--print-freq', '-p', default=1, type=int,
39 | metavar='N', help='print frequency (default: 1)')
40 | parser.add_argument('--depth', default=32, type=int,
41 | help='depth of the network (default: 32)')
42 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false',
43 | help='to use basicblock for CIFAR datasets (default: bottleneck)')
44 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str,
45 | help='dataset (options: cifar10, cifar100, and imagenet)')
46 | parser.add_argument('--alpha', default=300, type=float,
47 | help='number of new channel increases per depth (default: 300)')
48 | parser.add_argument('--no-verbose', dest='verbose', action='store_false',
49 | help='to print the status at every iteration')
50 | parser.add_argument('--pretrained', default='/set/your/model/path', type=str, metavar='PATH', help='path to the pretrained model checkpoint')
51 |
52 | parser.set_defaults(bottleneck=True)
53 | parser.set_defaults(verbose=True)
54 |
55 | best_err1 = 100
56 | best_err5 = 100
57 |
58 |
59 | def main():
60 | global args, best_err1, best_err5
61 | args = parser.parse_args()
62 |
63 | if args.dataset.startswith('cifar'):
64 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
65 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
66 |
67 | transform_train = transforms.Compose([
68 | transforms.RandomCrop(32, padding=4),
69 | transforms.RandomHorizontalFlip(),
70 | transforms.ToTensor(),
71 | normalize,
72 | ])
73 |
74 | transform_test = transforms.Compose([
75 | transforms.ToTensor(),
76 | normalize
77 | ])
78 |
79 | if args.dataset == 'cifar100':
80 | val_loader = torch.utils.data.DataLoader(
81 | datasets.CIFAR100('../data', train=False, transform=transform_test),
82 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
83 | numberofclass = 100
84 | elif args.dataset == 'cifar10':
85 | val_loader = torch.utils.data.DataLoader(
86 | datasets.CIFAR10('../data', train=False, transform=transform_test),
87 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
88 | numberofclass = 10
89 | else:
90 | raise Exception('unknown dataset: {}'.format(args.dataset))
91 |
92 | elif args.dataset == 'imagenet':
93 |
94 | valdir = os.path.join('/home/data/ILSVRC/val')
95 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
96 | std=[0.229, 0.224, 0.225])
97 |
98 | val_loader = torch.utils.data.DataLoader(
99 | datasets.ImageFolder(valdir, transforms.Compose([
100 | transforms.Resize(256),
101 | transforms.CenterCrop(224),
102 | transforms.ToTensor(),
103 | normalize,
104 | ])),
105 | batch_size=args.batch_size, shuffle=False,
106 | num_workers=args.workers, pin_memory=True)
107 | numberofclass = 1000
108 |
109 | else:
110 | raise Exception('unknown dataset: {}'.format(args.dataset))
111 |
112 | print("=> creating model '{}'".format(args.net_type))
113 | if args.net_type == 'resnet':
114 | model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck) # for ResNet
115 | elif args.net_type == 'pyramidnet':
116 | model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
117 | args.bottleneck)
118 | else:
119 | raise Exception('unknown network architecture: {}'.format(args.net_type))
120 |
121 | model = torch.nn.DataParallel(model).cuda()
122 |
123 | if os.path.isfile(args.pretrained):
124 | print("=> loading checkpoint '{}'".format(args.pretrained))
125 | checkpoint = torch.load(args.pretrained)
126 | model.load_state_dict(checkpoint['state_dict'])
127 | print("=> loaded checkpoint '{}'".format(args.pretrained))
128 | else:
129 | raise Exception("=> no checkpoint found at '{}'".format(args.pretrained))
130 |
131 | print(model)
132 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
133 |
134 | # define loss function (criterion) and optimizer
135 | criterion = nn.CrossEntropyLoss().cuda()
136 |
137 | cudnn.benchmark = True
138 |
139 | # evaluate on validation set
140 | err1, err5, val_loss = validate(val_loader, model, criterion)
141 |
142 | print('Top-1 and top-5 error (%):', err1, err5)
143 |
144 |
145 | def validate(val_loader, model, criterion):
146 | batch_time = AverageMeter()
147 | losses = AverageMeter()
148 | top1 = AverageMeter()
149 | top5 = AverageMeter()
150 |
151 | # switch to evaluate mode
152 | model.eval()
153 |
154 | end = time.time()
155 | for i, (input, target) in enumerate(val_loader):
156 | target = target.cuda()
157 |
158 | output = model(input)
159 | loss = criterion(output, target)
160 |
161 | # measure accuracy and record loss
162 | err1, err5 = accuracy(output.data, target, topk=(1, 5))
163 |
164 | losses.update(loss.item(), input.size(0))
165 |
166 | top1.update(err1.item(), input.size(0))
167 | top5.update(err5.item(), input.size(0))
168 |
169 | # measure elapsed time
170 | batch_time.update(time.time() - end)
171 | end = time.time()
172 |
173 | if i % args.print_freq == 0 and args.verbose:
174 | print('Test (on val set): [{0}/{1}]\t'
175 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
176 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
177 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
178 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
179 | i, len(val_loader), batch_time=batch_time, loss=losses,
180 | top1=top1, top5=top5))
181 |
182 | return top1.avg, top5.avg, losses.avg
183 |
184 |
185 | class AverageMeter(object):
186 | """Computes and stores the average and current value"""
187 |
188 | def __init__(self):
189 | self.reset()
190 |
191 | def reset(self):
192 | self.val = 0
193 | self.avg = 0
194 | self.sum = 0
195 | self.count = 0
196 |
197 | def update(self, val, n=1):
198 | self.val = val
199 | self.sum += val * n
200 | self.count += n
201 | self.avg = self.sum / self.count
202 |
203 |
204 | def accuracy(output, target, topk=(1,)):
205 | """Computes the precision@k for the specified values of k"""
206 | maxk = max(topk)
207 | batch_size = target.size(0)
208 |
209 | _, pred = output.topk(maxk, 1, True, True)
210 | pred = pred.t()
211 | correct = pred.eq(target.view(1, -1).expand_as(pred))
212 |
213 | res = []
214 | for k in topk:
215 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
216 | wrong_k = batch_size - correct_k
217 | res.append(wrong_k.mul_(100.0 / batch_size))
218 |
219 | return res
220 |
221 |
222 | if __name__ == '__main__':
223 | main()
224 |
--------------------------------------------------------------------------------
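`test.py` wraps the network in `torch.nn.DataParallel` before calling `load_state_dict`, which is why the released checkpoints (saved from a `DataParallel` model, so every key carries a `module.` prefix) load directly. A hedged sketch for loading one into a bare single-GPU/CPU model instead (the checkpoint path is illustrative):

```
import torch
import resnet as RN

model = RN.ResNet('imagenet', 50, 1000, True)
checkpoint = torch.load('model_best.pth.tar', map_location='cpu')
# Strip the 'module.' prefix that DataParallel adds to every parameter name.
state_dict = {k.replace('module.', '', 1): v for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
```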
/train.py:
--------------------------------------------------------------------------------
1 | # original code: https://github.com/dyhan0920/PyramidNet-PyTorch/blob/master/train.py
2 |
3 | import argparse
4 | import os
5 | import shutil
6 | import time
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.optim
13 | import torch.utils.data
14 | import torch.utils.data.distributed
15 | import torchvision.transforms as transforms
16 | import torchvision.datasets as datasets
17 | import torchvision.models as models
18 | import resnet as RN
19 | import pyramidnet as PYRM
20 | import utils
21 | import numpy as np
22 |
23 | import warnings
24 |
25 | warnings.filterwarnings("ignore")
26 |
27 | model_names = sorted(name for name in models.__dict__
28 | if name.islower() and not name.startswith("__")
29 | and callable(models.__dict__[name]))
30 |
31 | parser = argparse.ArgumentParser(description='CutMix PyTorch CIFAR-10, CIFAR-100 and ImageNet-1k Training')
32 | parser.add_argument('--net_type', default='pyramidnet', type=str,
33 | help='network type: resnet or pyramidnet')
34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
37 | help='number of total epochs to run')
38 | parser.add_argument('-b', '--batch_size', default=128, type=int,
39 | metavar='N', help='mini-batch size (default: 128)')
40 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
41 | metavar='LR', help='initial learning rate')
42 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
43 | help='momentum')
44 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
45 | metavar='W', help='weight decay (default: 1e-4)')
46 | parser.add_argument('--print-freq', '-p', default=1, type=int,
47 | metavar='N', help='print frequency (default: 1)')
48 | parser.add_argument('--depth', default=32, type=int,
49 | help='depth of the network (default: 32)')
50 | parser.add_argument('--no-bottleneck', dest='bottleneck', action='store_false',
51 | help='to use basicblock for CIFAR datasets (default: bottleneck)')
52 | parser.add_argument('--dataset', dest='dataset', default='imagenet', type=str,
53 | help='dataset (options: cifar10, cifar100, and imagenet)')
54 | parser.add_argument('--no-verbose', dest='verbose', action='store_false',
55 | help='to print the status at every iteration')
56 | parser.add_argument('--alpha', default=300, type=float,
57 | help='number of new channel increases per depth (default: 300)')
58 | parser.add_argument('--expname', default='TEST', type=str,
59 | help='name of experiment')
60 | parser.add_argument('--beta', default=0, type=float,
61 | help='hyperparameter beta')
62 | parser.add_argument('--cutmix_prob', default=0, type=float,
63 | help='cutmix probability')
64 |
65 | parser.set_defaults(bottleneck=True)
66 | parser.set_defaults(verbose=True)
67 |
68 | best_err1 = 100
69 | best_err5 = 100
70 |
71 |
72 | def main():
73 | global args, best_err1, best_err5
74 | args = parser.parse_args()
75 |
76 | if args.dataset.startswith('cifar'):
77 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
78 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
79 |
80 | transform_train = transforms.Compose([
81 | transforms.RandomCrop(32, padding=4),
82 | transforms.RandomHorizontalFlip(),
83 | transforms.ToTensor(),
84 | normalize,
85 | ])
86 |
87 | transform_test = transforms.Compose([
88 | transforms.ToTensor(),
89 | normalize
90 | ])
91 |
92 | if args.dataset == 'cifar100':
93 | train_loader = torch.utils.data.DataLoader(
94 | datasets.CIFAR100('../data', train=True, download=True, transform=transform_train),
95 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
96 | val_loader = torch.utils.data.DataLoader(
97 | datasets.CIFAR100('../data', train=False, transform=transform_test),
98 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
99 | numberofclass = 100
100 | elif args.dataset == 'cifar10':
101 | train_loader = torch.utils.data.DataLoader(
102 | datasets.CIFAR10('../data', train=True, download=True, transform=transform_train),
103 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
104 | val_loader = torch.utils.data.DataLoader(
105 | datasets.CIFAR10('../data', train=False, transform=transform_test),
106 | batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
107 | numberofclass = 10
108 | else:
109 | raise Exception('unknown dataset: {}'.format(args.dataset))
110 |
111 | elif args.dataset == 'imagenet':
112 | traindir = '/home/data/ILSVRC/train'
113 | valdir = '/home/data/ILSVRC/val'
114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
115 | std=[0.229, 0.224, 0.225])
116 |
117 | jittering = utils.ColorJitter(brightness=0.4, contrast=0.4,
118 | saturation=0.4)
119 | lighting = utils.Lighting(alphastd=0.1,
120 | eigval=[0.2175, 0.0188, 0.0045],
121 | eigvec=[[-0.5675, 0.7192, 0.4009],
122 | [-0.5808, -0.0045, -0.8140],
123 | [-0.5836, -0.6948, 0.4203]])
124 |
125 | train_dataset = datasets.ImageFolder(
126 | traindir,
127 | transforms.Compose([
128 | transforms.RandomResizedCrop(224),
129 | transforms.RandomHorizontalFlip(),
130 | transforms.ToTensor(),
131 | jittering,
132 | lighting,
133 | normalize,
134 | ]))
135 |
136 | train_sampler = None
137 |
138 | train_loader = torch.utils.data.DataLoader(
139 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
140 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
141 |
142 | val_loader = torch.utils.data.DataLoader(
143 | datasets.ImageFolder(valdir, transforms.Compose([
144 | transforms.Resize(256),
145 | transforms.CenterCrop(224),
146 | transforms.ToTensor(),
147 | normalize,
148 | ])),
149 | batch_size=args.batch_size, shuffle=False,
150 | num_workers=args.workers, pin_memory=True)
151 | numberofclass = 1000
152 |
153 | else:
154 | raise Exception('unknown dataset: {}'.format(args.dataset))
155 |
156 | print("=> creating model '{}'".format(args.net_type))
157 | if args.net_type == 'resnet':
158 | model = RN.ResNet(args.dataset, args.depth, numberofclass, args.bottleneck) # for ResNet
159 | elif args.net_type == 'pyramidnet':
160 | model = PYRM.PyramidNet(args.dataset, args.depth, args.alpha, numberofclass,
161 | args.bottleneck)
162 | else:
163 | raise Exception('unknown network architecture: {}'.format(args.net_type))
164 |
165 | model = torch.nn.DataParallel(model).cuda()
166 |
167 | print(model)
168 | print('the number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
169 |
170 | # define loss function (criterion) and optimizer
171 | criterion = nn.CrossEntropyLoss().cuda()
172 |
173 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
174 | momentum=args.momentum,
175 | weight_decay=args.weight_decay, nesterov=True)
176 |
177 | cudnn.benchmark = True
178 |
179 | for epoch in range(0, args.epochs):
180 |
181 | adjust_learning_rate(optimizer, epoch)
182 |
183 | # train for one epoch
184 | train_loss = train(train_loader, model, criterion, optimizer, epoch)
185 |
186 | # evaluate on validation set
187 | err1, err5, val_loss = validate(val_loader, model, criterion, epoch)
188 |
189 | # remember best prec@1 and save checkpoint
190 | is_best = err1 <= best_err1
191 | best_err1 = min(err1, best_err1)
192 | if is_best:
193 | best_err5 = err5
194 |
195 | print('Current best (top-1 and top-5 error):', best_err1, best_err5)
196 | save_checkpoint({
197 | 'epoch': epoch,
198 | 'arch': args.net_type,
199 | 'state_dict': model.state_dict(),
200 | 'best_err1': best_err1,
201 | 'best_err5': best_err5,
202 | 'optimizer': optimizer.state_dict(),
203 | }, is_best)
204 |
205 | print('Best (top-1 and top-5 error):', best_err1, best_err5)
206 |
207 |
208 | def train(train_loader, model, criterion, optimizer, epoch):
209 | batch_time = AverageMeter()
210 | data_time = AverageMeter()
211 | losses = AverageMeter()
212 | top1 = AverageMeter()
213 | top5 = AverageMeter()
214 |
215 | # switch to train mode
216 | model.train()
217 |
218 | end = time.time()
219 | current_LR = get_learning_rate(optimizer)[0]
220 | for i, (input, target) in enumerate(train_loader):
221 | # measure data loading time
222 | data_time.update(time.time() - end)
223 |
224 | input = input.cuda()
225 | target = target.cuda()
226 |
227 | r = np.random.rand(1)
228 | if args.beta > 0 and r < args.cutmix_prob:
229 | # generate mixed sample
230 | lam = np.random.beta(args.beta, args.beta)
231 | rand_index = torch.randperm(input.size()[0]).cuda()
232 | target_a = target
233 | target_b = target[rand_index]
234 | bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
235 | input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]
236 | # adjust lambda to exactly match pixel ratio
237 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
238 | # compute output
239 | output = model(input)
240 | loss = criterion(output, target_a) * lam + criterion(output, target_b) * (1. - lam)
241 | else:
242 | # compute output
243 | output = model(input)
244 | loss = criterion(output, target)
245 |
246 | # measure accuracy and record loss
247 | err1, err5 = accuracy(output.data, target, topk=(1, 5))
248 |
249 | losses.update(loss.item(), input.size(0))
250 | top1.update(err1.item(), input.size(0))
251 | top5.update(err5.item(), input.size(0))
252 |
253 | # compute gradient and do SGD step
254 | optimizer.zero_grad()
255 | loss.backward()
256 | optimizer.step()
257 |
258 | # measure elapsed time
259 | batch_time.update(time.time() - end)
260 | end = time.time()
261 |
262 | if i % args.print_freq == 0 and args.verbose:
263 | print('Epoch: [{0}/{1}][{2}/{3}]\t'
264 | 'LR: {LR:.6f}\t'
265 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
266 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
267 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
268 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
269 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
270 | epoch, args.epochs, i, len(train_loader), LR=current_LR, batch_time=batch_time,
271 | data_time=data_time, loss=losses, top1=top1, top5=top5))
272 |
273 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Train Loss {loss.avg:.3f}'.format(
274 | epoch, args.epochs, top1=top1, top5=top5, loss=losses))
275 |
276 | return losses.avg
277 |
278 |
279 | def rand_bbox(size, lam):
280 | W = size[2]
281 | H = size[3]
282 | cut_rat = np.sqrt(1. - lam)
283 | cut_w = int(W * cut_rat)
284 | cut_h = int(H * cut_rat)
285 |
286 | # uniform
287 | cx = np.random.randint(W)
288 | cy = np.random.randint(H)
289 |
290 | bbx1 = np.clip(cx - cut_w // 2, 0, W)
291 | bby1 = np.clip(cy - cut_h // 2, 0, H)
292 | bbx2 = np.clip(cx + cut_w // 2, 0, W)
293 | bby2 = np.clip(cy + cut_h // 2, 0, H)
294 |
295 | return bbx1, bby1, bbx2, bby2
296 |
297 |
298 | def validate(val_loader, model, criterion, epoch):
299 | batch_time = AverageMeter()
300 | losses = AverageMeter()
301 | top1 = AverageMeter()
302 | top5 = AverageMeter()
303 |
304 | # switch to evaluate mode
305 | model.eval()
306 |
307 | end = time.time()
308 | for i, (input, target) in enumerate(val_loader):
309 | target = target.cuda()
310 |
311 | output = model(input)
312 | loss = criterion(output, target)
313 |
314 | # measure accuracy and record loss
315 | err1, err5 = accuracy(output.data, target, topk=(1, 5))
316 |
317 | losses.update(loss.item(), input.size(0))
318 |
319 | top1.update(err1.item(), input.size(0))
320 | top5.update(err5.item(), input.size(0))
321 |
322 | # measure elapsed time
323 | batch_time.update(time.time() - end)
324 | end = time.time()
325 |
326 | if i % args.print_freq == 0 and args.verbose:
327 | print('Test (on val set): [{0}/{1}][{2}/{3}]\t'
328 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
329 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
330 | 'Top 1-err {top1.val:.4f} ({top1.avg:.4f})\t'
331 | 'Top 5-err {top5.val:.4f} ({top5.avg:.4f})'.format(
332 | epoch, args.epochs, i, len(val_loader), batch_time=batch_time, loss=losses,
333 | top1=top1, top5=top5))
334 |
335 | print('* Epoch: [{0}/{1}]\t Top 1-err {top1.avg:.3f} Top 5-err {top5.avg:.3f}\t Test Loss {loss.avg:.3f}'.format(
336 | epoch, args.epochs, top1=top1, top5=top5, loss=losses))
337 | return top1.avg, top5.avg, losses.avg
338 |
339 |
340 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
341 | directory = "runs/%s/" % (args.expname)
342 | if not os.path.exists(directory):
343 | os.makedirs(directory)
344 | filename = directory + filename
345 | torch.save(state, filename)
346 | if is_best:
347 | shutil.copyfile(filename, 'runs/%s/' % (args.expname) + 'model_best.pth.tar')
348 |
349 |
350 | class AverageMeter(object):
351 | """Computes and stores the average and current value"""
352 |
353 | def __init__(self):
354 | self.reset()
355 |
356 | def reset(self):
357 | self.val = 0
358 | self.avg = 0
359 | self.sum = 0
360 | self.count = 0
361 |
362 | def update(self, val, n=1):
363 | self.val = val
364 | self.sum += val * n
365 | self.count += n
366 | self.avg = self.sum / self.count
367 |
368 |
369 | def adjust_learning_rate(optimizer, epoch):
370 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
371 | if args.dataset.startswith('cifar'):
372 | lr = args.lr * (0.1 ** (epoch // (args.epochs * 0.5))) * (0.1 ** (epoch // (args.epochs * 0.75)))
373 | elif args.dataset == ('imagenet'):
374 | if args.epochs == 300:
375 | lr = args.lr * (0.1 ** (epoch // 75))
376 | else:
377 | lr = args.lr * (0.1 ** (epoch // 30))
378 |
379 | for param_group in optimizer.param_groups:
380 | param_group['lr'] = lr
381 |
382 |
383 | def get_learning_rate(optimizer):
384 | lr = []
385 | for param_group in optimizer.param_groups:
386 | lr += [param_group['lr']]
387 | return lr
388 |
389 |
390 | def accuracy(output, target, topk=(1,)):
391 | """Computes the precision@k for the specified values of k"""
392 | maxk = max(topk)
393 | batch_size = target.size(0)
394 |
395 | _, pred = output.topk(maxk, 1, True, True)
396 | pred = pred.t()
397 | correct = pred.eq(target.view(1, -1).expand_as(pred))
398 |
399 | res = []
400 | for k in topk:
401 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
402 | wrong_k = batch_size - correct_k
403 | res.append(wrong_k.mul_(100.0 / batch_size))
404 |
405 | return res
406 |
407 |
408 | if __name__ == '__main__':
409 | main()
410 |
--------------------------------------------------------------------------------
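One detail worth making concrete: `rand_bbox` samples a box whose nominal area is (1 - lam) of the image, but `np.clip` at the borders can shrink it, so `train()` recomputes `lam` from the box that was actually pasted before weighting the two losses. A small sketch (the printed values depend on the randomly drawn box center):

```
import torch
from train import rand_bbox

size = torch.Size([1, 3, 224, 224])
lam = 0.7                                   # nominal: keep 70% of the original image
bbx1, bby1, bbx2, bby2 = rand_bbox(size, lam)
box_area = (bbx2 - bbx1) * (bby2 - bby1)
lam_adjusted = 1 - box_area / (224 * 224)   # the value actually used to mix the labels
print((bbx1, bby1, bbx2, bby2), lam_adjusted)
```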
/utils.py:
--------------------------------------------------------------------------------
1 | # original code: https://github.com/eladhoffer/convNet.pytorch/blob/master/preprocess.py
2 |
3 | import torch
4 | import random
5 |
6 | __all__ = ["Compose", "Lighting", "ColorJitter"]
7 |
8 |
9 | class Compose(object):
10 | """Composes several transforms together.
11 |
12 | Args:
13 | transforms (list of ``Transform`` objects): list of transforms to compose.
14 |
15 | Example:
16 | >>> transforms.Compose([
17 | >>> transforms.CenterCrop(10),
18 | >>> transforms.ToTensor(),
19 | >>> ])
20 | """
21 |
22 | def __init__(self, transforms):
23 | self.transforms = transforms
24 |
25 | def __call__(self, img):
26 | for t in self.transforms:
27 | img = t(img)
28 | return img
29 |
30 | def __repr__(self):
31 | format_string = self.__class__.__name__ + '('
32 | for t in self.transforms:
33 | format_string += '\n'
34 | format_string += ' {0}'.format(t)
35 | format_string += '\n)'
36 | return format_string
37 |
38 |
39 | class Lighting(object):
40 | """Lighting noise(AlexNet - style PCA - based noise)"""
41 |
42 | def __init__(self, alphastd, eigval, eigvec):
43 | self.alphastd = alphastd
44 | self.eigval = torch.Tensor(eigval)
45 | self.eigvec = torch.Tensor(eigvec)
46 |
47 | def __call__(self, img):
48 | if self.alphastd == 0:
49 | return img
50 |
51 | alpha = img.new_empty(3).normal_(0, self.alphastd)
52 | rgb = self.eigvec.type_as(img).clone() \
53 | .mul(alpha.view(1, 3).expand(3, 3)) \
54 | .mul(self.eigval.view(1, 3).expand(3, 3)) \
55 | .sum(1).squeeze()
56 |
57 | return img.add(rgb.view(3, 1, 1).expand_as(img))
58 |
59 |
60 | class Grayscale(object):
61 |
62 | def __call__(self, img):
63 | gs = img.clone()
64 | gs[0].mul_(0.299).add_(gs[1], alpha=0.587).add_(gs[2], alpha=0.114)
65 | gs[1].copy_(gs[0])
66 | gs[2].copy_(gs[0])
67 | return gs
68 |
69 |
70 | class Saturation(object):
71 |
72 | def __init__(self, var):
73 | self.var = var
74 |
75 | def __call__(self, img):
76 | gs = Grayscale()(img)
77 | alpha = random.uniform(-self.var, self.var)
78 | return img.lerp(gs, alpha)
79 |
80 |
81 | class Brightness(object):
82 |
83 | def __init__(self, var):
84 | self.var = var
85 |
86 | def __call__(self, img):
87 | gs = torch.zeros_like(img)
88 | alpha = random.uniform(-self.var, self.var)
89 | return img.lerp(gs, alpha)
90 |
91 |
92 | class Contrast(object):
93 |
94 | def __init__(self, var):
95 | self.var = var
96 |
97 | def __call__(self, img):
98 | gs = Grayscale()(img)
99 | gs.fill_(gs.mean())
100 | alpha = random.uniform(-self.var, self.var)
101 | return img.lerp(gs, alpha)
102 |
103 |
104 | class ColorJitter(object):
105 |
106 | def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4):
107 | self.brightness = brightness
108 | self.contrast = contrast
109 | self.saturation = saturation
110 |
111 | def __call__(self, img):
112 | self.transforms = []
113 | if self.brightness != 0:
114 | self.transforms.append(Brightness(self.brightness))
115 | if self.contrast != 0:
116 | self.transforms.append(Contrast(self.contrast))
117 | if self.saturation != 0:
118 | self.transforms.append(Saturation(self.saturation))
119 |
120 | random.shuffle(self.transforms)
121 | transform = Compose(self.transforms)
122 | # print(transform)
123 | return transform(img)
124 |
--------------------------------------------------------------------------------
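These transforms operate on tensors rather than PIL images, which is why `train.py` places them after `ToTensor()` in its ImageNet pipeline. A minimal usage sketch (the random tensor is a stand-in for a converted training crop; the eigval/eigvec values are the ones used in `train.py`):

```
import torch
import utils

lighting = utils.Lighting(alphastd=0.1,
                          eigval=[0.2175, 0.0188, 0.0045],
                          eigvec=[[-0.5675, 0.7192, 0.4009],
                                  [-0.5808, -0.0045, -0.8140],
                                  [-0.5836, -0.6948, 0.4203]])
jittering = utils.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)

img = torch.rand(3, 224, 224)    # stand-in for a ToTensor()-converted crop
out = lighting(jittering(img))
print(out.shape)                 # torch.Size([3, 224, 224])
```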