├── .gitignore
├── figures
│   └── octave_conv.png
├── models
│   ├── __init__.py
│   └── octave_resnet.py
├── README.md
├── LICENSE
└── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# PyCharm editor settings
.idea

# macOS
.DS_Store

# logs
logs/

--------------------------------------------------------------------------------
/figures/octave_conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vivym/OctaveConv.pytorch/HEAD/figures/octave_conv.png

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152
from .octave_resnet import octave_resnet50, octave_resnet101, octave_resnet152

__all__ = [
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "octave_resnet50", "octave_resnet101", "octave_resnet152",
]

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OctaveConv.pytorch
A PyTorch implementation of the paper [Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution](https://arxiv.org/abs/1904.05049).
![](figures/octave_conv.png)


## Usage
```python
from models import octave_resnet50

model = octave_resnet50(num_classes=10)
```

Training follows the standard PyTorch ImageNet recipe, e.g. `python train.py -a octave_resnet50 /path/to/imagenet` (see `train.py` for the full set of flags).

## Reference
Inspired by the MXNet implementation [here](https://github.com/terrychenism/OctaveConv).

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Viv

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
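A quick sanity check that complements the README's usage snippet (a minimal sketch, assuming a CPU run on a random 224×224 input):

```python
import torch
from models import octave_resnet50

model = octave_resnet50(num_classes=10)
model.eval()

# layer4 merges the high/low-frequency branches back together,
# so the network ends in an ordinary logit vector.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 10])
```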
--------------------------------------------------------------------------------
/models/octave_resnet.py:
--------------------------------------------------------------------------------
from functools import partial

import torch.nn as nn
import torch.nn.functional as F


class OctConv(nn.Module):
    """Octave convolution (https://arxiv.org/abs/1904.05049).

    `type` selects how the high-frequency (hf) and low-frequency (lf)
    branches are wired:
      - 'first':  takes a plain tensor and splits it into an (hf, lf) pair.
      - 'normal': maps an (hf, lf) pair to an (hf, lf) pair.
      - 'last':   merges an (hf, lf) pair back into a plain tensor.
    `alpha_in`/`alpha_out` are the fractions of input/output channels assigned
    to the low-frequency branch, which lives at half the spatial resolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 alpha_in=0.25, alpha_out=0.25, type='normal'):
        super(OctConv, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.type = type
        hf_ch_in = int(in_channels * (1 - alpha_in))
        hf_ch_out = int(out_channels * (1 - alpha_out))
        lf_ch_in = in_channels - hf_ch_in
        lf_ch_out = out_channels - hf_ch_out

        if type == 'first':
            if stride == 2:
                self.downsample = nn.AvgPool2d(kernel_size=2, stride=stride)
            self.convh = nn.Conv2d(
                in_channels, hf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding,
            )
            self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
            self.convl = nn.Conv2d(
                in_channels, lf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding,
            )
        elif type == 'last':
            if stride == 2:
                self.downsample = nn.AvgPool2d(kernel_size=2, stride=stride)
            self.convh = nn.Conv2d(hf_ch_in, out_channels, kernel_size=kernel_size, padding=padding)
            self.convl = nn.Conv2d(lf_ch_in, out_channels, kernel_size=kernel_size, padding=padding)
            self.upsample = partial(F.interpolate, scale_factor=2, mode="nearest")
        else:
            if stride == 2:
                self.downsample = nn.AvgPool2d(kernel_size=2, stride=stride)

            # The four paths of a 'normal' octave convolution: within-frequency
            # updates (H2H, L2L) and cross-frequency exchanges (L2H, H2L).
            self.L2L = nn.Conv2d(
                lf_ch_in, lf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding
            )
            self.L2H = nn.Conv2d(
                lf_ch_in, hf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding
            )
            self.H2L = nn.Conv2d(
                hf_ch_in, lf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding
            )
            self.H2H = nn.Conv2d(
                hf_ch_in, hf_ch_out,
                kernel_size=kernel_size, stride=1, padding=padding
            )
            self.upsample = partial(F.interpolate, scale_factor=2, mode="nearest")
            self.avg_pool = partial(F.avg_pool2d, kernel_size=2, stride=2)

    def forward(self, x):
        if self.type == 'first':
            if self.stride == 2:
                x = self.downsample(x)

            hf = self.convh(x)
            lf = self.avg_pool(x)
            lf = self.convl(lf)

            return hf, lf
        elif self.type == 'last':
            hf, lf = x
            if self.stride == 2:
                # hf is pooled down to lf's resolution, so lf needs no upsampling.
                hf = self.downsample(hf)
                return self.convh(hf) + self.convl(lf)
            else:
                return self.convh(hf) + self.convl(self.upsample(lf))
        else:
            hf, lf = x
            if self.stride == 2:
                hf = self.downsample(hf)
                return self.H2H(hf) + self.L2H(lf), \
                    self.L2L(self.avg_pool(lf)) + self.H2L(self.avg_pool(hf))
            else:
                return self.H2H(hf) + self.upsample(self.L2H(lf)), self.L2L(lf) + self.H2L(self.avg_pool(hf))


def norm_conv3x3(in_planes, out_planes, stride=1, type=None):
    """3x3 convolution with padding (`type` is accepted for API parity with
    oct_conv3x3 and ignored)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


def norm_conv1x1(in_planes, out_planes, stride=1, type=None):
    """1x1 convolution (`type` is accepted for API parity and ignored)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def oct_conv3x3(in_planes, out_planes, stride=1, type='normal'):
    """3x3 octave convolution with padding"""
    return OctConv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, type=type)


def oct_conv1x1(in_planes, out_planes, stride=1, type='normal'):
    """1x1 octave convolution"""
    return OctConv(in_planes, out_planes, kernel_size=1, stride=stride, type=type)


class _BatchNorm2d(nn.Module):
    """BatchNorm applied separately to the hf and lf branches."""

    def __init__(self, num_features, alpha_in=0.25, alpha_out=0.25, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm2d, self).__init__()
        hf_ch = int(num_features * (1 - alpha_in))
        lf_ch = num_features - hf_ch
        self.bnh = nn.BatchNorm2d(hf_ch, eps=eps, momentum=momentum, affine=affine,
                                  track_running_stats=track_running_stats)
        self.bnl = nn.BatchNorm2d(lf_ch, eps=eps, momentum=momentum, affine=affine,
                                  track_running_stats=track_running_stats)

    def forward(self, x):
        hf, lf = x
        return self.bnh(hf), self.bnl(lf)


class _ReLU(nn.ReLU):
    """ReLU applied separately to the hf and lf branches."""

    def forward(self, x):
        hf, lf = x
        hf = super(_ReLU, self).forward(hf)
        lf = super(_ReLU, self).forward(lf)
        return hf, lf


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, type="normal", oct_conv_on=True):
        super(BasicBlock, self).__init__()
        conv3x3 = oct_conv3x3 if oct_conv_on else norm_conv3x3
        norm_func = _BatchNorm2d if oct_conv_on else nn.BatchNorm2d
        act_func = _ReLU if oct_conv_on else nn.ReLU

        self.conv1 = conv3x3(inplanes, planes, type="first" if type == "first" else "normal")
        self.bn1 = norm_func(planes)
        self.relu1 = act_func(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride, type="last" if type == "last" else "normal")
        # After a 'last' octave convolution the two branches are merged, so the
        # remaining layers operate on plain tensors.
        if type == "last":
            norm_func = nn.BatchNorm2d
            act_func = nn.ReLU
        self.bn2 = norm_func(planes)
        self.relu2 = act_func(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Octave blocks carry (hf, lf) pairs, so the residual addition is done
        # per branch.
        if isinstance(out, (tuple, list)):
            assert len(out) == len(identity) and len(out) == 2
            out = (out[0] + identity[0], out[1] + identity[1])
        else:
            out += identity

        out = self.relu2(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, type="normal", oct_conv_on=True):
        super(Bottleneck, self).__init__()
        conv1x1 = oct_conv1x1 if oct_conv_on else norm_conv1x1
        conv3x3 = oct_conv3x3 if oct_conv_on else norm_conv3x3
        norm_func = _BatchNorm2d if oct_conv_on else nn.BatchNorm2d
        act_func = _ReLU if oct_conv_on else nn.ReLU

        self.conv1 = conv1x1(inplanes, planes, type="first" if type == "first" else "normal")
        self.bn1 = norm_func(planes)
        self.relu1 = act_func(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride, type="last" if type == "last" else "normal")
        # After a 'last' octave convolution the two branches are merged, so the
        # remaining layers operate on plain tensors.
        if type == "last":
            conv1x1 = norm_conv1x1
            norm_func = nn.BatchNorm2d
            act_func = nn.ReLU
        self.bn2 = norm_func(planes)
        self.relu2 = act_func(inplace=True)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = norm_func(planes * self.expansion)
        self.relu3 = act_func(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        if isinstance(out, (tuple, list)):
            assert len(out) == len(identity) and len(out) == 2
            out = (out[0] + identity[0], out[1] + identity[1])
        else:
            out += identity
        out = self.relu3(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # layer1 splits the stem output into (hf, lf); layer4 merges it back.
        self.layer1 = self._make_layer(block, 64, layers[0], type="first")
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, type="last")
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, type="normal"):
        downsample = None
        # A 'first' block always needs a downsample path so that the identity
        # is also split into (hf, lf) to match the block output.
        if stride != 1 or self.inplanes != planes * block.expansion or type == 'first':
            norm_func = nn.BatchNorm2d if type == "last" else _BatchNorm2d
            downsample = nn.Sequential(
                oct_conv1x1(self.inplanes, planes * block.expansion, stride, type=type),
                norm_func(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, type=type))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, oct_conv_on=type != "last"))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def octave_resnet18(pretrained=False, **kwargs):
    """Constructs an Octave ResNet-18 model.

    Args:
        pretrained (bool): accepted for API compatibility; pre-trained weights
            are not available for the octave variants, so the flag is ignored
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def octave_resnet34(pretrained=False, **kwargs):
    """Constructs an Octave ResNet-34 model.

    Args:
        pretrained (bool): accepted for API compatibility; currently ignored
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def octave_resnet50(pretrained=False, **kwargs):
    """Constructs an Octave ResNet-50 model.

    Args:
        pretrained (bool): accepted for API compatibility; currently ignored
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def octave_resnet101(pretrained=False, **kwargs):
    """Constructs an Octave ResNet-101 model.

    Args:
        pretrained (bool): accepted for API compatibility; currently ignored
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def octave_resnet152(pretrained=False, **kwargs):
    """Constructs an Octave ResNet-152 model.

    Args:
        pretrained (bool): accepted for API compatibility; currently ignored
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
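The three `OctConv` modes can be exercised in isolation; a minimal sketch (the channel splits shown assume the default `alpha_in = alpha_out = 0.25`):

```python
import torch
from models.octave_resnet import OctConv

x = torch.randn(1, 64, 56, 56)

# 'first' splits a plain tensor into high/low-frequency branches;
# with alpha=0.25, 48 of the 64 output channels are high-frequency.
first = OctConv(64, 64, kernel_size=3, padding=1, type='first')
hf, lf = first(x)
print(hf.shape, lf.shape)  # (1, 48, 56, 56), (1, 16, 28, 28)

# 'normal' keeps both branches, exchanging information between them.
normal = OctConv(64, 64, kernel_size=3, padding=1, type='normal')
hf, lf = normal((hf, lf))

# 'last' upsamples the low-frequency branch and merges it back in.
last = OctConv(64, 64, kernel_size=3, padding=1, type='last')
out = last((hf, lf))
print(out.shape)  # (1, 64, 56, 56)
```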
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import models

# Every lowercase callable exported by the models package (the torchvision
# resnets plus the octave variants) becomes a valid value for --arch.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
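# Example invocations (illustrative; adjust paths and flags to your setup):
#   single node, DataParallel over all visible GPUs:
#     python train.py -a octave_resnet50 /path/to/imagenet
#   single node, one process per GPU with DistributedDataParallel:
#     python train.py -a octave_resnet50 --dist-url 'tcp://127.0.0.1:23456' \
#         --dist-backend nccl --multiprocessing-distributed \
#         --world-size 1 --rank 0 /path/to/imagenet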
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

best_acc1 = 0


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()
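    # Note: the octave_resnet* constructors accept `pretrained` but do not
    # load any weights, so --pretrained currently only has an effect for the
    # plain torchvision resnet* architectures re-exported from models/__init__.py.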
    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
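    # Both loaders assume the standard ImageNet directory layout expected by
    # torchvision.datasets.ImageFolder: one subdirectory per class under both
    # train/ and val/.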
    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                    and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best)


def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, data_time, losses, top1,
                             top5, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.print(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5,
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.print(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg
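# Checkpoints written by save_checkpoint() can be fed back in via --resume,
# e.g. (illustrative):
#   python train.py -a octave_resnet50 --resume checkpoint.pth.tar /path/to/imagenet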
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (rather than view) because this slice of a transposed
            # tensor is not contiguous
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
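As a worked illustration of the top-k accuracy computation in `train.py` (a standalone sketch that mirrors `accuracy()`; the logits below are made up):

```python
import torch

output = torch.tensor([[0.10, 0.70, 0.20],
                       [0.80, 0.05, 0.15]])  # logits: 2 samples, 3 classes
target = torch.tensor([1, 2])                # sample 0 is top-1 correct

_, pred = output.topk(2, 1, True, True)      # top-2 class indices per sample
pred = pred.t()                              # shape (k, batch)
correct = pred.eq(target.view(1, -1).expand_as(pred))

top1 = correct[:1].reshape(-1).float().sum() * (100.0 / 2)  # 50.0
top2 = correct[:2].reshape(-1).float().sum() * (100.0 / 2)  # 100.0
print(top1.item(), top2.item())
```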