├── .gitignore ├── LICENSE ├── README.md ├── cutmix.py ├── dataloader.py ├── figures ├── cutmix.png ├── w_cutmix.png └── wo_cutmix.png ├── resnet_preact.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 hysts 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Implementation of CutMix 2 | 3 | ![](figures/cutmix.png) 4 | 5 | ## Usage 6 | 7 | ``` 8 | $ python train.py --depth 20 --use_cutmix --outdir results 9 | ``` 10 | 11 | 12 | ## Results on CIFAR-10 13 | 14 | | Model | Test Error (median of 3 runs) | Training Time | 15 | |:---------------------------|:-----------------------------:|--------------:| 16 | | WRN-20-4 | 4.56 | 1h22m | 17 | | WRN-20-4, CutMix (alpha=1) | 3.62 | 1h22m | 18 | 19 | * These models were trained for 300 epochs with batch size 128, initial learning rate 0.2, and cosine annealing. 
def cutmix(batch, alpha):
    """Apply CutMix to an already collated batch.

    One rectangular region, shared by every sample in the batch, is
    overwritten with the same region taken from a randomly permuted copy
    of the batch.

    Args:
        batch: ``(data, targets)`` pair. ``data`` is indexed as
            ``data[:, :, y0:y1, x0:x1]``, so it needs four dimensions with
            height and width last; ``targets`` is indexed along its first
            dimension with the same permutation.
        alpha: Value used for both shape parameters of the Beta
            distribution the mixing coefficient is drawn from.

    Returns:
        ``(data, (targets, shuffled_targets, lam))``. ``data`` is the input
        tensor itself, modified in place. ``lam`` is a Python float giving
        the fraction of every image that was left untouched.
    """
    data, targets = batch

    # The order of the random draws (permutation, lambda, box centre) is
    # kept fixed so that seeded runs keep producing the same boxes.
    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]  # advanced indexing -> independent copy
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    cut_ratio = np.sqrt(1 - lam)
    w = image_w * cut_ratio
    h = image_h * cut_ratio
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]

    # The box is clipped at the image border and rounded to whole pixels,
    # so the pasted area generally differs from the sampled ``1 - lam``.
    # Re-derive lam from the real box so the label weights match the
    # pixel ratio actually present in the mixed images.
    box_area = (x1 - x0) * (y1 - y0)
    lam = 1.0 - box_area / float(image_h * image_w)

    return data, (targets, shuffled_targets, lam)
class CutMixCriterion:
    """Cross-entropy loss for the ``(targets1, targets2, lam)`` triples
    produced by ``cutmix``.

    The loss is the ``lam``-weighted blend of the cross-entropy against
    each of the two label sets.
    """

    def __init__(self, reduction):
        # ``reduction`` is handed straight to ``nn.CrossEntropyLoss``.
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, preds, targets):
        """Return ``lam * CE(preds, targets1) + (1 - lam) * CE(preds, targets2)``."""
        labels_a, labels_b, lam = targets
        loss_a = self.criterion(preds, labels_a)
        loss_b = self.criterion(preds, labels_b)
        return lam * loss_a + (1 - lam) * loss_b
def initialize_weights(module):
    """Initialise a single module in place (used with ``nn.Module.apply``).

    * ``nn.Conv2d``: weight re-drawn with Kaiming-normal init in
      ``fan_out`` mode.
    * ``nn.BatchNorm2d``: weight filled with 1, bias zeroed.
    * ``nn.Linear``: bias zeroed; the weight is left as constructed.

    Every other module type is left untouched. Parameters that do not
    exist (``BatchNorm2d(affine=False)``, ``Linear(bias=False)``) are
    skipped instead of raising ``AttributeError``.
    """
    if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight.data, mode='fan_out')
        return

    if isinstance(module, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False
        if module.weight is not None:
            module.weight.data.fill_(1)
        if module.bias is not None:
            module.bias.data.zero_()
        return

    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()
class BottleneckBlock(nn.Module):
    """Pre-activation bottleneck residual block (1x1 -> 3x3 -> 1x1 convs).

    The two inner convolutions work on ``out_channels // expansion``
    channels; spatial downsampling is done by the 3x3 convolution.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels of the block output.
        stride: Stride of the 3x3 convolution and of the projection
            shortcut.
        remove_first_relu: When true (and ``preact`` is false), the ReLU
            after the first batch norm is skipped.
        add_last_bn: When true, a batch norm (``bn4``) is applied to the
            residual branch before the addition.
        preact: When true, the first BN + ReLU is applied to the input
            itself, so the shortcut also sees the activated tensor;
            otherwise only the residual branch is normalised.
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 remove_first_relu,
                 add_last_bn,
                 preact=False):
        super(BottleneckBlock, self).__init__()

        self._remove_first_relu = remove_first_relu
        self._add_last_bn = add_last_bn
        self._preact = preact

        bottleneck_channels = out_channels // self.expansion

        # Sub-modules are created in a fixed order: it determines both the
        # state-dict layout and the order in which random init is drawn.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)
        self.bn2 = nn.BatchNorm2d(bottleneck_channels)
        self.conv2 = nn.Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride,  # downsample with 3x3 conv
            padding=1,
            bias=False)
        self.bn3 = nn.BatchNorm2d(bottleneck_channels)
        self.conv3 = nn.Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False)

        if add_last_bn:
            self.bn4 = nn.BatchNorm2d(out_channels)

        self.shortcut = nn.Sequential()  # identity
        # A projection is needed whenever the residual branch changes the
        # tensor shape: a different channel count OR a spatial stride.
        # Checking only the channels left an identity shortcut for
        # stride != 1 with equal channels, which fails at the addition.
        if in_channels != out_channels or stride != 1:
            self.shortcut.add_module(
                'conv',
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,  # downsample
                    padding=0,
                    bias=False))

    def forward(self, x):
        """Return ``residual(x) + shortcut(x)``."""
        if self._preact:
            x = F.relu(
                self.bn1(x), inplace=True)  # shortcut after preactivation
            y = self.conv1(x)
        else:
            # preactivation only for residual path
            y = self.bn1(x)
            if not self._remove_first_relu:
                y = F.relu(y, inplace=True)
            y = self.conv1(y)

        y = F.relu(self.bn2(y), inplace=True)
        y = self.conv2(y)
        y = F.relu(self.bn3(y), inplace=True)
        y = self.conv3(y)

        if self._add_last_bn:
            y = self.bn4(y)

        y += self.shortcut(x)
        return y
n_blocks_per_stage = (depth - 2) // 6 186 | assert n_blocks_per_stage * 6 + 2 == depth 187 | else: 188 | block = BottleneckBlock 189 | n_blocks_per_stage = (depth - 2) // 9 190 | assert n_blocks_per_stage * 9 + 2 == depth 191 | 192 | n_channels = [ 193 | base_channels, 194 | base_channels * 2 * block.expansion, 195 | base_channels * 4 * block.expansion, 196 | ] 197 | 198 | self.conv = nn.Conv2d( 199 | input_shape[1], 200 | n_channels[0], 201 | kernel_size=(3, 3), 202 | stride=1, 203 | padding=1, 204 | bias=False) 205 | 206 | self.stage1 = self._make_stage( 207 | n_channels[0], 208 | n_channels[0], 209 | n_blocks_per_stage, 210 | block, 211 | stride=1, 212 | preact=preact_stage[0]) 213 | self.stage2 = self._make_stage( 214 | n_channels[0], 215 | n_channels[1], 216 | n_blocks_per_stage, 217 | block, 218 | stride=2, 219 | preact=preact_stage[1]) 220 | self.stage3 = self._make_stage( 221 | n_channels[1], 222 | n_channels[2], 223 | n_blocks_per_stage, 224 | block, 225 | stride=2, 226 | preact=preact_stage[2]) 227 | self.bn = nn.BatchNorm2d(n_channels[2]) 228 | 229 | # compute conv feature size 230 | with torch.no_grad(): 231 | self.feature_size = self._forward_conv( 232 | torch.zeros(*input_shape)).view(-1).shape[0] 233 | 234 | self.fc = nn.Linear(self.feature_size, n_classes) 235 | 236 | # initialize weights 237 | self.apply(initialize_weights) 238 | 239 | def _make_stage(self, in_channels, out_channels, n_blocks, block, stride, 240 | preact): 241 | stage = nn.Sequential() 242 | for index in range(n_blocks): 243 | block_name = 'block{}'.format(index + 1) 244 | if index == 0: 245 | stage.add_module( 246 | block_name, 247 | block( 248 | in_channels, 249 | out_channels, 250 | stride=stride, 251 | remove_first_relu=self._remove_first_relu, 252 | add_last_bn=self._add_last_bn, 253 | preact=preact)) 254 | else: 255 | stage.add_module( 256 | block_name, 257 | block( 258 | out_channels, 259 | out_channels, 260 | stride=1, 261 | remove_first_relu=self._remove_first_relu, 262 | 
def str2bool(s):
    """Convert the text ``'true'``/``'false'`` (any letter case) to a bool.

    Raises:
        RuntimeError: if ``s`` is neither spelling.
    """
    known = {'true': True, 'false': False}
    key = s.lower()
    if key not in known:
        raise RuntimeError('Boolean value expected')
    return known[key]
parser.add_argument('--base_channels', type=int, default=16) 55 | 56 | # cutmix 57 | parser.add_argument('--use_cutmix', action='store_true') 58 | parser.add_argument('--cutmix_alpha', type=float, default=1.0) 59 | 60 | # run config 61 | parser.add_argument('--outdir', type=str, required=True) 62 | parser.add_argument('--seed', type=int, default=0) 63 | parser.add_argument('--num_workers', type=int, default=4) 64 | parser.add_argument('--device', type=str, default='cuda') 65 | 66 | # optim config 67 | parser.add_argument('--epochs', type=int, default=300) 68 | parser.add_argument('--batch_size', type=int, default=128) 69 | parser.add_argument('--base_lr', type=float, default=0.2) 70 | parser.add_argument('--weight_decay', type=float, default=1e-4) 71 | parser.add_argument('--momentum', type=float, default=0.9) 72 | parser.add_argument('--nesterov', type=str2bool, default=True) 73 | parser.add_argument( 74 | '--scheduler', 75 | type=str, 76 | default='cosine', 77 | choices=['multistep', 'cosine']) 78 | parser.add_argument('--milestones', type=str, default='[150, 225]') 79 | parser.add_argument('--lr_decay', type=float, default=0.1) 80 | 81 | # TensorBoard 82 | parser.add_argument( 83 | '--no-tensorboard', dest='tensorboard', action='store_false') 84 | 85 | args = parser.parse_args() 86 | if not is_tensorboard_available: 87 | args.tensorboard = False 88 | 89 | model_config = OrderedDict([ 90 | ('arch', 'resnet_preact'), 91 | ('block_type', args.block_type), 92 | ('depth', args.depth), 93 | ('base_channels', args.base_channels), 94 | ('input_shape', (1, 3, 32, 32)), 95 | ('n_classes', 10), 96 | ]) 97 | 98 | optim_config = OrderedDict([ 99 | ('epochs', args.epochs), 100 | ('batch_size', args.batch_size), 101 | ('base_lr', args.base_lr), 102 | ('weight_decay', args.weight_decay), 103 | ('momentum', args.momentum), 104 | ('nesterov', args.nesterov), 105 | ('scheduler', args.scheduler), 106 | ('milestones', json.loads(args.milestones)), 107 | ('lr_decay', args.lr_decay), 
class AverageMeter:
    """Keep a running weighted mean of a scalar.

    After each ``update`` the instance exposes:
      * ``val``   - the value passed to the most recent update,
      * ``sum``   - the accumulated ``val * num``,
      * ``count`` - the accumulated ``num``,
      * ``avg``   - ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, num):
        """Record ``val`` with weight ``num`` and refresh the mean."""
        self.val = val
        self.count += num
        self.sum += val * num
        self.avg = self.sum / self.count
optimizer.zero_grad() 185 | 186 | outputs = model(data) 187 | loss = criterion(outputs, targets) 188 | loss.backward() 189 | 190 | optimizer.step() 191 | 192 | _, preds = torch.max(outputs, dim=1) 193 | 194 | loss_ = loss.item() 195 | 196 | num = data.size(0) 197 | if isinstance(targets, (tuple, list)): 198 | targets1, targets2, lam = targets 199 | correct1 = preds.eq(targets1).sum().item() 200 | correct2 = preds.eq(targets2).sum().item() 201 | accuracy = (lam * correct1 + (1 - lam) * correct2) / num 202 | else: 203 | correct_ = preds.eq(targets).sum().item() 204 | accuracy = correct_ / num 205 | 206 | loss_meter.update(loss_, num) 207 | accuracy_meter.update(accuracy, num) 208 | 209 | if run_config['tensorboard']: 210 | writer.add_scalar('Train/RunningLoss', loss_, global_step) 211 | writer.add_scalar('Train/RunningAccuracy', accuracy, global_step) 212 | 213 | if step % 100 == 0: 214 | logger.info('Epoch {} Step {}/{} ' 215 | 'Loss {:.4f} ({:.4f}) ' 216 | 'Accuracy {:.4f} ({:.4f})'.format( 217 | epoch, 218 | step, 219 | len(train_loader), 220 | loss_meter.val, 221 | loss_meter.avg, 222 | accuracy_meter.val, 223 | accuracy_meter.avg, 224 | )) 225 | 226 | elapsed = time.time() - start 227 | logger.info('Elapsed {:.2f}'.format(elapsed)) 228 | 229 | if run_config['tensorboard']: 230 | writer.add_scalar('Train/Loss', loss_meter.avg, epoch) 231 | writer.add_scalar('Train/Accuracy', accuracy_meter.avg, epoch) 232 | writer.add_scalar('Train/Time', elapsed, epoch) 233 | 234 | train_log = OrderedDict({ 235 | 'epoch': 236 | epoch, 237 | 'train': 238 | OrderedDict({ 239 | 'loss': loss_meter.avg, 240 | 'accuracy': accuracy_meter.avg, 241 | 'time': elapsed, 242 | }), 243 | }) 244 | return train_log 245 | 246 | 247 | def test(epoch, model, criterion, test_loader, run_config, writer): 248 | logger.info('Test {}'.format(epoch)) 249 | 250 | model.eval() 251 | device = torch.device(run_config['device']) 252 | 253 | loss_meter = AverageMeter() 254 | correct_meter = AverageMeter() 255 
def main():
    """Run the whole experiment described by the command-line options.

    Steps: parse the options, seed the RNGs, create the output directory,
    write ``config.json``, build the data loaders, model, loss functions,
    optimizer and LR scheduler, evaluate once before training, then for
    every epoch train, evaluate, rewrite ``log.json`` and overwrite the
    ``model_state.pth`` checkpoint.
    """
    # parse command line arguments
    config = parse_args()
    logger.info(json.dumps(config, indent=2))

    run_config = config['run_config']
    optim_config = config['optim_config']
    data_config = config['data_config']

    # set random seed
    seed = run_config['seed']
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # create output directory
    outdir = pathlib.Path(run_config['outdir'])
    outdir.mkdir(exist_ok=True, parents=True)

    # TensorBoard SummaryWriter
    writer = SummaryWriter(
        outdir.as_posix()) if run_config['tensorboard'] else None

    try:
        # save config as json file in output directory
        with open(outdir / 'config.json', 'w') as fout:
            json.dump(config, fout, indent=2)

        # data loaders
        train_loader, test_loader = get_loader(
            optim_config['batch_size'], run_config['num_workers'],
            data_config)

        # model
        model = load_model(config['model_config'])
        model.to(torch.device(run_config['device']))
        n_params = sum(param.numel() for param in model.parameters())
        logger.info('n_params: {}'.format(n_params))

        # CutMix batches carry (targets1, targets2, lam), so training
        # needs the blended loss; evaluation always uses plain labels.
        if data_config['use_cutmix']:
            train_criterion = CutMixCriterion(reduction='mean')
        else:
            train_criterion = nn.CrossEntropyLoss(reduction='mean')
        test_criterion = nn.CrossEntropyLoss(reduction='mean')

        # optimizer
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=optim_config['base_lr'],
            momentum=optim_config['momentum'],
            weight_decay=optim_config['weight_decay'],
            nesterov=optim_config['nesterov'])
        if optim_config['scheduler'] == 'multistep':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=optim_config['milestones'],
                gamma=optim_config['lr_decay'])
        else:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, optim_config['epochs'], 0)

        # run test before start training
        test(0, model, test_criterion, test_loader, run_config, writer)

        epoch_logs = []
        for epoch in range(1, optim_config['epochs'] + 1):
            train_log = train(epoch, model, optimizer, train_criterion,
                              train_loader, run_config, writer)
            # Advance the schedule only after this epoch's optimizer
            # updates. Since PyTorch 1.1 stepping the scheduler first
            # skips the initial learning rate (and, with cosine annealing
            # down to 0, leaves the final epoch training at lr == 0).
            # NOTE(review): on PyTorch < 1.1 this shifts the schedule by
            # one epoch - confirm the minimum supported version.
            scheduler.step()

            test_log = test(epoch, model, test_criterion, test_loader,
                            run_config, writer)

            epoch_log = train_log.copy()
            epoch_log.update(test_log)
            epoch_logs.append(epoch_log)
            with open(outdir / 'log.json', 'w') as fout:
                json.dump(epoch_logs, fout, indent=2)

            # checkpoint is overwritten every epoch
            state = OrderedDict([
                ('config', config),
                ('state_dict', model.state_dict()),
                ('optimizer', optimizer.state_dict()),
                ('epoch', epoch),
                ('accuracy', test_log['test']['accuracy']),
            ])
            model_path = outdir / 'model_state.pth'
            torch.save(state, model_path)
    finally:
        # Flush and release the TensorBoard event file even if training
        # is interrupted or fails.
        if writer is not None:
            writer.close()