├── LICENSE ├── README.md ├── assets └── data_folder.PNG ├── learning ├── evaluator.py ├── lr_scheduler.py └── trainer.py ├── main.py ├── network └── resnet.py ├── option.py ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hoseong Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # swa-tutorials-pytorch 3 | [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) Tutorials using pytorch. Based on [PyTorch 1.6 Official Features (Stochastic Weight Averaging)](https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/), implement classification codebase using custom dataset. 4 | 5 | - author: hoya012 6 | - last update: 2020.10.23 7 | 8 | ## 0. Experimental Setup 9 | ### 0-1. Prepare Library 10 | - Need to install PyTorch and Captum 11 | 12 | ```python 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ### 0-2. Download dataset (Kaggle Intel Image Classification) 17 | 18 | - [Intel Image Classification](https://www.kaggle.com/puneet6060/intel-image-classification/) 19 | 20 | This Data contains around 25k images of size 150x150 distributed under 6 categories. 21 | {'buildings' -> 0, 22 | 'forest' -> 1, 23 | 'glacier' -> 2, 24 | 'mountain' -> 3, 25 | 'sea' -> 4, 26 | 'street' -> 5 } 27 | 28 | - Make `data` folder and move dataset into `data` folder. 29 | 30 |
align="center">
31 |   <img src="assets/data_folder.PNG" alt="data folder layout">
32 | </div>
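The data loaders in `utils.py` read from `data/seg_train/seg_train` and `data/seg_test/seg_test` under the repository root. As a quick sanity check before training, a minimal sketch (assuming the Kaggle archive keeps its default nesting):

```python
# Optional check that the dataset landed where utils.make_dataloader expects it.
# The roots below mirror utils.py; adjust them if your extraction nests differently.
from torchvision.datasets import ImageFolder

for root in ("data/seg_train/seg_train", "data/seg_test/seg_test"):
    ds = ImageFolder(root=root)
    # ImageFolder assigns labels alphabetically, which matches the mapping above
    # (buildings -> 0, ..., street -> 5).
    print(root, "->", len(ds), "images,", ds.classes)
```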
33 | 34 | ### 1. Baseline Training 35 | - ImageNet Pretrained ResNet-18 from torchvision.models 36 | - Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001 37 | - Training Augmentation: Resize((256, 256)), RandomHorizontalFlip() 38 | - Adam + Cosine Learning rate scheduling with warmup 39 | - I tried NVIDIA Pascal GPU - GTX 1080 Ti 1 GPU 40 | 41 | ```python 42 | python main.py --checkpoint_name baseline; 43 | ``` 44 | 45 | ### 2. Stochastic Weight Averaging Training 46 | 47 | In PyTorch 1.6, Stochastic Weight Averaging is very easy to use! Thanks to PyTorch.. 48 | 49 | - PyTorch's official tutorial's guide 50 | ```python 51 | from torch.optim.swa_utils import AveragedModel, SWALR 52 | from torch.optim.lr_scheduler import CosineAnnealingLR 53 | 54 | loader, optimizer, model, loss_fn = ... 55 | swa_model = AveragedModel(model) 56 | scheduler = CosineAnnealingLR(optimizer, T_max=100) 57 | swa_start = 5 58 | swa_scheduler = SWALR(optimizer, swa_lr=0.05) 59 | 60 | for epoch in range(100): 61 | for input, target in loader: 62 | optimizer.zero_grad() 63 | loss_fn(model(input), target).backward() 64 | optimizer.step() 65 | if epoch > swa_start: 66 | swa_model.update_parameters(model) 67 | swa_scheduler.step() 68 | else: 69 | scheduler.step() 70 | 71 | # Update bn statistics for the swa_model at the end 72 | torch.optim.swa_utils.update_bn(loader, swa_model) 73 | # Use swa_model to make predictions on test data 74 | preds = swa_model(test_input) 75 | ``` 76 | 77 | - My own implementations 78 | ```python 79 | # in main.py 80 | """ define model and learning rate scheduler for stochastic weight averaging """ 81 | swa_model = torch.optim.swa_utils.AveragedModel(model) 82 | swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr) 83 | 84 | ... 85 | 86 | # in learning/trainer.py 87 | for batch_idx, (inputs, labels) in enumerate(data_loader): 88 | if not args.decay_type == 'swa': 89 | self.scheduler.step() 90 | else: 91 | if epoch <= args.swa_start: 92 | self.scheduler.step() 93 | 94 | if epoch > args.swa_start and args.decay_type == 'swa': 95 | self.swa_model.update_parameters(self.model) 96 | self.swa_scheduler.step() 97 | 98 | ... 99 | 100 | # in main.py 101 | swa_model = swa_model.cpu() 102 | torch.optim.swa_utils.update_bn(train_loader, swa_model) 103 | swa_model = swa_model.cuda() 104 | ``` 105 | 106 | #### Run Script (Command Line) 107 | ```python 108 | python main.py --checkpoint_name swa --decay_type swa --swa_start 90 --swa_lr 5e-5; 109 | ``` 110 | 111 | ### 3. Performance Table 112 | - B : Baseline 113 | - SWA : Stochastic Weight Averaging 114 | - SWA_{swa_start}_{swa_lr} 115 | 116 | | Algorithm | Test Accuracy | 117 | |:------------:|:-------------:| 118 | | B | 94.10 | 119 | | SWA_90_0.05 | 80.53 | 120 | | SWA_90_1e-4 | 94.20 | 121 | | SWA_90_5e-4 | 93.87 | 122 | | SWA_90_1e-5 | 94.23 | 123 | | SWA_90_5e-5 | **94.57** | 124 | | SWA_75_5e-5 | 94.27 | 125 | | SWA_60_5e-5 | 94.33 | 126 | 127 | ### 4. 
Code Reference 128 | - Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch 129 | - Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr 130 | - PyTorch Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging -------------------------------------------------------------------------------- /assets/data_folder.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/swa-tutorials-pytorch/ab8cace9a1457a74cadfe6671598237a4d389cc0/assets/data_folder.PNG -------------------------------------------------------------------------------- /learning/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from utils import AverageMeter, accuracy 5 | 6 | class Evaluator(): 7 | def __init__(self, model, criterion, swa_model=None): 8 | self.model = model 9 | self.criterion = criterion 10 | self.swa_model = swa_model 11 | self.save_path = os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name, 'result_dict.json') 12 | if not os.path.exists(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)): 13 | os.makedirs(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)) 14 | 15 | def worst_result(self): 16 | ret = { 17 | 'loss': float('inf'), 18 | 'accuracy': 0.0 19 | } 20 | return ret 21 | 22 | def result_to_str(self, result): 23 | ret = [ 24 | 'epoch: {epoch:0>3}', 25 | 'loss: {loss: >4.2e}' 26 | ] 27 | for metric in self.evaluation_metrics: 28 | ret.append('{}: {}'.format(metric.name, metric.fmtstr)) 29 | return ', '.join(ret).format(**result) 30 | 31 | def save(self, result): 32 | with open(self.save_path, 'w') as f: 33 | f.write(json.dumps(result, sort_keys=True, indent=4, ensure_ascii=False)) 34 | 35 | def load(self): 36 | result = self.worst_result 37 | if os.path.exists(self.save_path): 38 | with open(self.save_path, 'r') as f: 39 | try: 40 | result = json.loads(f.read()) 41 | except: 42 | pass 43 | return result 44 | 45 | def evaluate(self, data_loader, epoch, args, result_dict): 46 | losses = AverageMeter() 47 | top1 = AverageMeter() 48 | 49 | self.model.eval() 50 | 51 | total_loss = 0 52 | with torch.no_grad(): 53 | for batch_idx, (inputs, labels) in enumerate(data_loader): 54 | inputs, labels = inputs.cuda(), labels.cuda() 55 | if args.amp: 56 | with torch.cuda.amp.autocast(): 57 | outputs = self.model(inputs) 58 | loss = self.criterion(outputs, labels) 59 | else: 60 | outputs = self.model(inputs) 61 | loss = self.criterion(outputs, labels) 62 | 63 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 64 | losses.update(loss.item(), inputs.size(0)) 65 | top1.update(prec1.item(), inputs.size(0)) 66 | 67 | print('----Validation Results Summary----') 68 | print('Epoch: [{}] Top-1 accuracy: {:.2f}%'.format(epoch, top1.avg)) 69 | 70 | result_dict['val_loss'].append(losses.avg) 71 | result_dict['val_acc'].append(top1.avg) 72 | 73 | return result_dict 74 | 75 | def test(self, data_loader, args, result_dict): 76 | top1 = AverageMeter() 77 | 78 | if args.decay_type == 'swa': 79 | self.model = self.swa_model 80 | 81 | self.model.eval() 82 | with torch.no_grad(): 83 | for batch_idx, (inputs, labels) in enumerate(data_loader): 84 | inputs, labels = inputs.cuda(), labels.cuda() 85 | outputs = self.model(inputs) 86 | 87 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 88 | 
top1.update(prec1.item(), inputs.size(0)) 89 | 90 | print('----Test Set Results Summary----') 91 | print('Top-1 accuracy: {:.2f}%'.format(top1.avg)) 92 | 93 | result_dict['test_acc'].append(top1.avg) 94 | 95 | return result_dict 96 | 97 | -------------------------------------------------------------------------------- /learning/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | """ 6 | Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Code Reference: https://github.com/ildoonet/pytorch-gradual-warmup-lr 10 | """ 11 | 12 | class GradualWarmupScheduler(_LRScheduler): 13 | """ Gradually warm-up(increasing) learning rate in optimizer. 14 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 15 | Args: 16 | optimizer (Optimizer): Wrapped optimizer. 17 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 18 | total_epoch: target learning rate is reached at total_epoch, gradually 19 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 20 | """ 21 | 22 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 23 | self.multiplier = multiplier 24 | if self.multiplier < 1.: 25 | raise ValueError('multiplier should be greater thant or equal to 1.') 26 | self.total_epoch = total_epoch 27 | self.after_scheduler = after_scheduler 28 | self.finished = False 29 | super(GradualWarmupScheduler, self).__init__(optimizer) 30 | 31 | def get_lr(self): 32 | if self.last_epoch > self.total_epoch: 33 | if self.after_scheduler: 34 | if not self.finished: 35 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 36 | self.finished = True 37 | return self.after_scheduler.get_last_lr() 38 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 39 | 40 | if self.multiplier == 1.0: 41 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 42 | else: 43 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) 
for base_lr in self.base_lrs] 44 | 45 | def step(self, epoch=None, metrics=None): 46 | if self.finished and self.after_scheduler: 47 | if epoch is None: 48 | self.after_scheduler.step(None) 49 | else: 50 | self.after_scheduler.step(epoch - self.total_epoch) 51 | self._last_lr = self.after_scheduler.get_last_lr() 52 | else: 53 | return super(GradualWarmupScheduler, self).step(epoch) 54 | -------------------------------------------------------------------------------- /learning/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import AverageMeter, accuracy 4 | 5 | class Trainer: 6 | def __init__(self, model, criterion, optimizer, scheduler, scaler, swa_model=None, swa_scheduler=None): 7 | self.model = model 8 | self.criterion = criterion 9 | self.optimizer = optimizer 10 | self.scheduler = scheduler 11 | self.scaler = scaler 12 | self.swa_model = swa_model 13 | self.swa_scheduler = swa_scheduler 14 | 15 | def train(self, data_loader, epoch, args, result_dict): 16 | total_loss = 0 17 | count = 0 18 | 19 | losses = AverageMeter() 20 | top1 = AverageMeter() 21 | 22 | self.model.train() 23 | 24 | for batch_idx, (inputs, labels) in enumerate(data_loader): 25 | inputs, labels = inputs.cuda(), labels.cuda() 26 | 27 | if args.amp: 28 | with torch.cuda.amp.autocast(): 29 | outputs = self.model(inputs) 30 | loss = self.criterion(outputs, labels) 31 | else: 32 | outputs = self.model(inputs) 33 | loss = self.criterion(outputs, labels) 34 | 35 | if len(labels.size()) > 1: 36 | labels = torch.argmax(labels, axis=1) 37 | 38 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 39 | losses.update(loss.item(), inputs.size(0)) 40 | top1.update(prec1.item(), inputs.size(0)) 41 | 42 | self.optimizer.zero_grad() 43 | 44 | if args.amp: 45 | self.scaler.scale(loss).backward() 46 | self.scaler.step(self.optimizer) 47 | self.scaler.update() 48 | else: 49 | loss.backward() 50 | self.optimizer.step() 51 | 52 | 53 | total_loss += loss.tolist() 54 | count += labels.size(0) 55 | 56 | if batch_idx % args.log_interval == 0: 57 | _s = str(len(str(len(data_loader.sampler)))) 58 | ret = [ 59 | ('epoch: {:0>3} [{: >' + _s + '}/{} ({: >3.0f}%)]').format(epoch, count, len(data_loader.sampler), 100 * count / len(data_loader.sampler)), 60 | 'train_loss: {: >4.2e}'.format(total_loss / count), 61 | 'train_accuracy : {:.2f}%'.format(top1.avg) 62 | ] 63 | print(', '.join(ret)) 64 | 65 | if not args.decay_type == 'swa': 66 | self.scheduler.step() 67 | else: 68 | if epoch <= args.swa_start: 69 | self.scheduler.step() 70 | 71 | if epoch > args.swa_start and args.decay_type == 'swa': 72 | self.swa_model.update_parameters(self.model) 73 | self.swa_scheduler.step() 74 | 75 | result_dict['train_loss'].append(losses.avg) 76 | result_dict['train_acc'].append(top1.avg) 77 | 78 | return result_dict -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | from torch.optim.swa_utils import SWALR 7 | 8 | PATH = os.path.dirname(os.path.abspath(__file__)) 9 | sys.path.insert(0, PATH + '/../..') 10 | 11 | from option import get_args 12 | from learning.trainer import Trainer 13 | from learning.evaluator import Evaluator 14 | from utils import get_model, make_optimizer, make_scheduler, make_dataloader, 
plot_learning_curves 15 | 16 | def main(): 17 | args = get_args() 18 | torch.manual_seed(args.seed) 19 | torch.backends.cudnn.benchmark = True 20 | 21 | shape = (224,224,3) 22 | 23 | """ define dataloader """ 24 | train_loader, valid_loader, test_loader = make_dataloader(args) 25 | 26 | """ define model architecture """ 27 | model = get_model(args, shape, args.num_classes) 28 | 29 | if torch.cuda.device_count() >= 1: 30 | print('Model pushed to {} GPU(s), type {}.'.format(torch.cuda.device_count(), torch.cuda.get_device_name(0))) 31 | model = model.cuda() 32 | else: 33 | raise ValueError('CPU training is not supported') 34 | 35 | """ define loss criterion """ 36 | criterion = nn.CrossEntropyLoss().cuda() 37 | 38 | """ define optimizer """ 39 | optimizer = make_optimizer(args, model) 40 | 41 | """ define learning rate scheduler """ 42 | scheduler = make_scheduler(args, optimizer) 43 | 44 | """ define loss scaler for automatic mixed precision """ 45 | scaler = torch.cuda.amp.GradScaler() 46 | 47 | """ define model and learning rate scheduler for stochastic weight averaging """ 48 | swa_model = torch.optim.swa_utils.AveragedModel(model) 49 | swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr) 50 | 51 | """ define trainer, evaluator, result_dictionary """ 52 | result_dict = {'args':vars(args), 'epoch':[], 'train_loss' : [], 'train_acc' : [], 'val_loss' : [], 'val_acc' : [], 'test_acc':[]} 53 | trainer = Trainer(model, criterion, optimizer, scheduler, scaler, swa_model, swa_scheduler) 54 | evaluator = Evaluator(model, criterion, swa_model) 55 | 56 | train_time_list = [] 57 | valid_time_list = [] 58 | 59 | if args.evaluate: 60 | """ load model checkpoint """ 61 | model.load() 62 | result_dict = evaluator.test(test_loader, args, result_dict) 63 | else: 64 | evaluator.save(result_dict) 65 | 66 | best_val_acc = 0.0 67 | """ define training loop """ 68 | for epoch in range(args.epochs): 69 | result_dict['epoch'] = epoch 70 | 71 | torch.cuda.synchronize() 72 | tic1 = time.time() 73 | 74 | result_dict = trainer.train(train_loader, epoch, args, result_dict) 75 | 76 | torch.cuda.synchronize() 77 | tic2 = time.time() 78 | train_time_list.append(tic2 - tic1) 79 | 80 | torch.cuda.synchronize() 81 | tic3 = time.time() 82 | 83 | result_dict = evaluator.evaluate(valid_loader, epoch, args, result_dict) 84 | 85 | torch.cuda.synchronize() 86 | tic4 = time.time() 87 | valid_time_list.append(tic4 - tic3) 88 | 89 | if result_dict['val_acc'][-1] > best_val_acc: 90 | print("{} epoch, best epoch was updated! 
{}%".format(epoch, result_dict['val_acc'][-1])) 91 | best_val_acc = result_dict['val_acc'][-1] 92 | model.save(checkpoint_name='best_model') 93 | 94 | evaluator.save(result_dict) 95 | plot_learning_curves(result_dict, epoch, args) 96 | 97 | if args.decay_type == 'swa': 98 | swa_model = swa_model.cpu() 99 | torch.optim.swa_utils.update_bn(train_loader, swa_model) 100 | swa_model = swa_model.cuda() 101 | 102 | result_dict = evaluator.test(test_loader, args, result_dict) 103 | evaluator.save(result_dict) 104 | 105 | """ calculate test accuracy using best model """ 106 | model.load(checkpoint_name='best_model') 107 | result_dict = evaluator.test(test_loader, args, result_dict) 108 | evaluator.save(result_dict) 109 | 110 | print(result_dict) 111 | 112 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'train_time_amp.csv'), train_time_list, delimiter=',', fmt='%s') 113 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'valid_time_amp.csv'), valid_time_list, delimiter=',', fmt='%s') 114 | 115 | if __name__ == '__main__': 116 | main() -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import resnet 5 | 6 | __all__ = [ 7 | 'ResNet18', 'ResNet50' 8 | ] 9 | 10 | class Flatten(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, x): 15 | return x.view(x.size(0), -1) 16 | 17 | def conv3x3(in_channel, out_channel, stride=1): 18 | return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | def conv1x1(in_channel, out_channel, stride=1): 21 | return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False) 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layer_config, num_classes=2, norm='batch', zero_init_residual=False): 25 | super(ResNet, self).__init__() 26 | norm = nn.BatchNorm2d 27 | 28 | self.in_channel = 64 29 | 30 | self.conv = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False) 31 | self.norm = norm(self.in_channel) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 34 | self.layer1 = self.make_layer(block, 64*block.expansion, layer_config[0], stride=1, norm=norm) 35 | self.layer2 = self.make_layer(block, 128*block.expansion, layer_config[1], stride=2, norm=norm) 36 | self.layer3 = self.make_layer(block, 256*block.expansion, layer_config[2], stride=2, norm=norm) 37 | self.layer4 = self.make_layer(block, 512*block.expansion, layer_config[3], stride=2, norm=norm) 38 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 39 | self.flatten = Flatten() 40 | self.dense = nn.Linear(512*block.expansion, num_classes) 41 | 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 45 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 46 | nn.init.constant_(m.weight, 1) 47 | nn.init.constant_(m.bias, 0) 48 | 49 | # Zero-initialize the last BN in each residual branch, 50 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
51 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 52 | if zero_init_residual: 53 | for m in self.modules(): 54 | if isinstance(m, BasicBlock): 55 | nn.init.constant_(m.norm2.weight, 0) 56 | elif isinstance(m, Bottleneck): 57 | nn.init.constant_(m.norm3.weight, 0) 58 | 59 | def make_layer(self, block, out_channel, num_blocks, stride=1, norm=None): 60 | norm = nn.BatchNorm2d 61 | 62 | downsample = None 63 | if stride != 1 or self.in_channel != out_channel: 64 | downsample = nn.Sequential( 65 | conv1x1(self.in_channel, out_channel, stride), 66 | norm(out_channel), 67 | ) 68 | layers = [] 69 | layers.append(block(self.in_channel, out_channel, stride, downsample, norm)) 70 | self.in_channel = out_channel 71 | for _ in range(1, num_blocks): 72 | layers.append(block(self.in_channel, out_channel, norm=norm)) 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv(x) 77 | out = self.norm(out) 78 | out = self.relu(out) 79 | out = self.maxpool(out) 80 | out = self.layer1(out) 81 | out = self.layer2(out) 82 | out = self.layer3(out) 83 | out = self.layer4(out) 84 | out = self.avgpool(out) 85 | out = self.flatten(out) 86 | out = self.dense(out) 87 | return out 88 | 89 | 90 | class BasicBlock(nn.Module): 91 | expansion = 1 92 | 93 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None): 94 | super(BasicBlock, self).__init__() 95 | norm = nn.BatchNorm2d 96 | 97 | self.conv1 = conv3x3(in_channel, out_channel, stride) 98 | self.norm1 = norm(out_channel) 99 | self.conv2 = conv3x3(out_channel, out_channel) 100 | self.norm2 = norm(out_channel) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | 104 | def forward(self, x): 105 | identity = x 106 | out = self.conv1(x) 107 | out = self.norm1(out) 108 | out = self.relu(out) 109 | out = self.conv2(out) 110 | out = self.norm2(out) 111 | if self.downsample is not None: 112 | identity = self.downsample(identity) 113 | out += identity 114 | out = self.relu(out) 115 | return out 116 | 117 | class Bottleneck(nn.Module): 118 | expansion = 4 119 | 120 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None): 121 | super(Bottleneck, self).__init__() 122 | norm = nn.BatchNorm2d 123 | 124 | mid_channel= out_channel // self.expansion 125 | self.conv1 = conv1x1(in_channel, mid_channel) 126 | self.norm1 = norm(mid_channel) 127 | self.conv2 = conv3x3(mid_channel, mid_channel, stride) 128 | self.norm2 = norm(mid_channel) 129 | self.conv3 = conv1x1(mid_channel, out_channel) 130 | self.norm3 = norm(out_channel) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.downsample = downsample 133 | 134 | def forward(self, x): 135 | identity = x 136 | out = self.conv1(x) 137 | out = self.norm1(out) 138 | out = self.relu(out) 139 | out = self.conv2(out) 140 | out = self.norm2(out) 141 | out = self.relu(out) 142 | out = self.conv3(out) 143 | out = self.norm3(out) 144 | if self.downsample is not None: 145 | identity = self.downsample(x) 146 | out += identity 147 | out = self.relu(out) 148 | return out 149 | 150 | 151 | class ResNet18(nn.Module): 152 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet18', 153 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False): 154 | super(ResNet18, self).__init__() 155 | 156 | if len(shape) != 3: 157 | raise ValueError('Invalid shape: {}'.format(shape)) 158 | self.shape = shape 159 | self.num_classes = num_classes 160 | 
self.checkpoint_dir = checkpoint_dir 161 | self.checkpoint_name = checkpoint_name 162 | self.H, self.W, self.C = shape 163 | 164 | if not os.path.exists(checkpoint_dir): 165 | os.makedirs(checkpoint_dir) 166 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt') 167 | 168 | model = ResNet(BasicBlock, [2,2,2, 2], num_classes, norm, zero_init_residual) 169 | 170 | if pretrained: 171 | print("Pretrained Weight is loaded!!") 172 | if pretrained_path is None: 173 | print("Loading from torchvision models") 174 | model = resnet.resnet18(pretrained=True) 175 | if zero_init_residual: 176 | for m in model.modules(): 177 | if isinstance(m, resnet.Bottleneck): 178 | nn.init.constant_(m.bn3.weight, 0) 179 | elif isinstance(m, BasicBlock): 180 | nn.init.constant_(m.bn2.weight, 0) 181 | else: 182 | checkpoint = torch.load(pretrained_path) 183 | model.load_state_dict(checkpoint) 184 | 185 | self.features = nn.Sequential(*list(model.children())[:-2]) 186 | self.num_features = 512 * BasicBlock.expansion 187 | self.classifier = nn.Sequential( 188 | nn.AdaptiveAvgPool2d((1, 1)), 189 | Flatten(), 190 | nn.Linear(self.num_features, num_classes) 191 | ) 192 | 193 | def save(self, checkpoint_name=''): 194 | if checkpoint_name == '': 195 | torch.save(self.state_dict(), self.checkpoint_path) 196 | else: 197 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 198 | torch.save(self.state_dict(), checkpoint_path) 199 | 200 | def load(self, checkpoint_name=''): 201 | if checkpoint_name == '': 202 | self.load_state_dict(torch.load(self.checkpoint_path)) 203 | else: 204 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 205 | self.load_state_dict(torch.load(checkpoint_path)) 206 | 207 | def forward(self, x): 208 | out = x 209 | out = self.features(out) 210 | out = self.classifier(out) 211 | return out 212 | 213 | class ResNet50(nn.Module): 214 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet50', 215 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False): 216 | super(ResNet50, self).__init__() 217 | 218 | if len(shape) != 3: 219 | raise ValueError('Invalid shape: {}'.format(shape)) 220 | self.shape = shape 221 | self.num_classes = num_classes 222 | self.checkpoint_dir = checkpoint_dir 223 | self.checkpoint_name = checkpoint_name 224 | self.H, self.W, self.C = shape 225 | 226 | if not os.path.exists(checkpoint_dir): 227 | os.makedirs(checkpoint_dir) 228 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt') 229 | 230 | model = ResNet(Bottleneck, [3,4,6,3], num_classes, norm, zero_init_residual) 231 | 232 | if pretrained: 233 | print("Pretrained Weight is loaded!!") 234 | if pretrained_path is None: 235 | print("Loading from torchvision models") 236 | model = resnet.resnet50(pretrained=True) 237 | if zero_init_residual: 238 | for m in model.modules(): 239 | if isinstance(m, resnet.Bottleneck): 240 | nn.init.constant_(m.bn3.weight, 0) 241 | else: 242 | checkpoint = torch.load(pretrained_path) 243 | model.load_state_dict(checkpoint) 244 | 245 | self.features = nn.Sequential(*list(model.children())[:-2]) 246 | self.num_features = 512 * Bottleneck.expansion 247 | self.classifier = nn.Sequential( 248 | nn.AdaptiveAvgPool2d((1, 1)), 249 | Flatten(), 250 | nn.Linear(self.num_features, num_classes) 251 | ) 252 | 253 | def save(self, checkpoint_name=''): 254 | if checkpoint_name == '': 255 | 
torch.save(self.state_dict(), self.checkpoint_path) 256 | else: 257 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 258 | torch.save(self.state_dict(), checkpoint_path) 259 | 260 | def load(self, checkpoint_name=''): 261 | if checkpoint_name == '': 262 | self.load_state_dict(torch.load(self.checkpoint_path)) 263 | else: 264 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 265 | self.load_state_dict(torch.load(checkpoint_path)) 266 | 267 | def forward(self, x): 268 | out = x 269 | out = self.features(out) 270 | out = self.classifier(out) 271 | return out -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # model architecture & checkpoint 7 | parser.add_argument('--model', default='ResNet18', choices=('ResNet18', 'ResNet50'), 8 | help='optimizer to use (ResNet18 | ResNet50)') 9 | parser.add_argument('--norm', default='batchnorm') 10 | parser.add_argument('--num_classes', type=int, default=6) 11 | parser.add_argument('--pretrained', type=int, default=1) 12 | parser.add_argument('--pretrained_path', type=str, default=None) 13 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint') 14 | parser.add_argument('--checkpoint_name', type=str, default='') 15 | 16 | # data loading 17 | parser.add_argument('--num_workers', type=int, default=16) 18 | parser.add_argument('--seed', type=int, default=42, help='random seed') 19 | 20 | # training hyper parameters 21 | parser.add_argument('--batch_size', type=int, default=256) 22 | parser.add_argument('--epochs', type=int, default=120) 23 | parser.add_argument('--log_interval', type=int, default=20) 24 | parser.add_argument('--evaluate', action='store_true', default=False) 25 | parser.add_argument('--amp', action='store_true', default=False) 26 | 27 | # optimzier & learning rate scheduler 28 | parser.add_argument('--learning_rate', type=float, default=0.0001) 29 | parser.add_argument('--weight_decay', type=float, default=0.0001) 30 | parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM'), 31 | help='optimizer to use (SGD | ADAM)') 32 | parser.add_argument('--decay_type', default='cosine_warmup', choices=('step', 'step_warmup', 'cosine_warmup', 'swa'), 33 | help='optimizer to use (step | step_warmup | cosine_warmup | stochastic weight averaging)') 34 | parser.add_argument('--swa_start', type=int, default=90) 35 | parser.add_argument('--swa_lr', type=float, default=0.05) 36 | 37 | 38 | args = parser.parse_args() 39 | return args 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | captum==0.2.0 2 | cycler==0.10.0 3 | decorator==4.4.2 4 | decord==0.4.0 5 | future==0.18.2 6 | imageio==2.8.0 7 | joblib==0.16.0 8 | kiwisolver==1.2.0 9 | matplotlib==3.2.1 10 | networkx==2.4 11 | numpy==1.18.5 12 | opencv-python==4.2.0.34 13 | pandas==1.0.4 14 | Pillow==7.1.2 15 | pyparsing==2.4.7 16 | python-dateutil==2.8.1 17 | pytz==2020.1 18 | PyWavelets==1.1.1 19 | scikit-image==0.17.2 20 | scikit-learn==0.23.1 21 | scipy==1.4.1 22 | six==1.15.0 23 | threadpoolctl==2.1.0 24 | tifffile==2020.6.3 25 | torch==1.6.0 26 | torchsummary==1.5.1 27 | torchvision==0.7.0 28 | tqdm==4.48.2 29 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from sklearn.model_selection import train_test_split 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as lrs 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | import torchvision 12 | from torchvision import transforms as T 13 | 14 | from network.resnet import * 15 | from learning.lr_scheduler import GradualWarmupScheduler 16 | 17 | def get_model(args, shape, num_classes): 18 | model = eval(args.model)( 19 | shape, 20 | num_classes, 21 | checkpoint_dir=args.checkpoint_dir, 22 | checkpoint_name=args.checkpoint_name, 23 | pretrained=args.pretrained, 24 | pretrained_path=args.pretrained_path, 25 | norm=args.norm, 26 | ) 27 | return model 28 | 29 | 30 | def make_optimizer(args, model): 31 | trainable = filter(lambda x: x.requires_grad, model.parameters()) 32 | 33 | if args.optimizer == 'SGD': 34 | optimizer_function = optim.SGD 35 | kwargs = {'momentum': 0.9} 36 | elif args.optimizer == 'ADAM': 37 | optimizer_function = optim.Adam 38 | kwargs = { 39 | 'betas': (0.9, 0.999), 40 | 'eps': 1e-08 41 | } 42 | else: 43 | raise NameError('Not Supportes Optimizer') 44 | 45 | kwargs['lr'] = args.learning_rate 46 | kwargs['weight_decay'] = args.weight_decay 47 | 48 | return optimizer_function(trainable, **kwargs) 49 | 50 | 51 | def make_scheduler(args, optimizer): 52 | if args.decay_type == 'step': 53 | scheduler = lrs.MultiStepLR( 54 | optimizer, 55 | milestones=[30, 60, 90], 56 | gamma=0.1 57 | ) 58 | elif args.decay_type == 'step_warmup': 59 | scheduler = lrs.MultiStepLR( 60 | optimizer, 61 | milestones=[30, 60, 90], 62 | gamma=0.1 63 | ) 64 | scheduler = GradualWarmupScheduler( 65 | optimizer, 66 | multiplier=1, 67 | total_epoch=5, 68 | after_scheduler=scheduler 69 | ) 70 | elif args.decay_type == 'cosine_warmup': 71 | cosine_scheduler = lrs.CosineAnnealingLR( 72 | optimizer, 73 | T_max=args.epochs 74 | ) 75 | scheduler = GradualWarmupScheduler( 76 | optimizer, 77 | multiplier=1, 78 | total_epoch=args.epochs//10, 79 | after_scheduler=cosine_scheduler 80 | ) 81 | elif args.decay_type == 'swa': 82 | scheduler = lrs.CosineAnnealingLR( 83 | optimizer, 84 | T_max=args.epochs 85 | ) 86 | else: 87 | raise Exception('unknown lr scheduler: {}'.format(args.decay_type)) 88 | 89 | return scheduler 90 | 91 | def make_dataloader(args): 92 | 93 | train_trans = T.Compose([ 94 | T.Resize((256, 256)), 95 | T.RandomHorizontalFlip(), 96 | T.ToTensor(), 97 | ]) 98 | 99 | valid_trans = T.Compose([ 100 | T.Resize((256, 256)), 101 | T.ToTensor(), 102 | ]) 103 | 104 | test_trans = T.Compose([ 105 | T.Resize((256, 256)), 106 | T.ToTensor(), 107 | ]) 108 | 109 | trainset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=train_trans) 110 | validset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=valid_trans) 111 | testset = torchvision.datasets.ImageFolder(root="data/seg_test/seg_test", transform=test_trans) 112 | 113 | np.random.seed(args.seed) 114 | targets = trainset.targets 115 | train_idx, valid_idx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets) 116 | train_sampler = SubsetRandomSampler(train_idx) 117 | valid_sampler = SubsetRandomSampler(valid_idx) 118 | 119 | train_loader = 
torch.utils.data.DataLoader( 120 | trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers 121 | ) 122 | 123 | valid_loader = torch.utils.data.DataLoader( 124 | validset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=args.num_workers 125 | ) 126 | 127 | test_loader = torch.utils.data.DataLoader( 128 | testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 129 | ) 130 | 131 | return train_loader, valid_loader, test_loader 132 | 133 | 134 | def plot_learning_curves(metrics, cur_epoch, args): 135 | x = np.arange(cur_epoch+1) 136 | fig, ax1 = plt.subplots() 137 | ax1.set_xlabel('epochs') 138 | ax1.set_ylabel('loss') 139 | ln1 = ax1.plot(x, metrics['train_loss'], color='tab:red') 140 | ln2 = ax1.plot(x, metrics['val_loss'], color='tab:red', linestyle='dashed') 141 | ax1.grid() 142 | ax2 = ax1.twinx() 143 | ax2.set_ylabel('accuracy') 144 | ln3 = ax2.plot(x, metrics['train_acc'], color='tab:blue') 145 | ln4 = ax2.plot(x, metrics['val_acc'], color='tab:blue', linestyle='dashed') 146 | lns = ln1+ln2+ln3+ln4 147 | plt.legend(lns, ['Train loss', 'Validation loss', 'Train accuracy','Validation accuracy']) 148 | plt.tight_layout() 149 | plt.savefig('{}/{}/learning_curve.png'.format(args.checkpoint_dir, args.checkpoint_name), bbox_inches='tight') 150 | plt.close('all') 151 | 152 | class AverageMeter(object): 153 | """Computes and stores the average and current value""" 154 | def __init__(self): 155 | self.reset() 156 | 157 | def reset(self): 158 | self.val = 0 159 | self.avg = 0 160 | self.sum = 0 161 | self.count = 0 162 | self.max = 0 163 | self.min = 1e5 164 | 165 | def update(self, val, n=1): 166 | self.val = val 167 | self.sum += val * n 168 | self.count += n 169 | self.avg = self.sum / self.count 170 | if val > self.max: 171 | self.max = val 172 | if val < self.min: 173 | self.min = val 174 | 175 | def accuracy(output, target, topk=(1,)): 176 | """Computes the precision@k for the specified values of k""" 177 | maxk = max(topk) 178 | batch_size = target.size(0) 179 | 180 | _, pred = output.topk(maxk, 1, True, True) 181 | pred = pred.t() 182 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 183 | 184 | res = [] 185 | for k in topk: 186 | correct_k = correct[:k].view(-1).float().sum(0) 187 | res.append(correct_k.mul_(100.0 / batch_size)) 188 | return res --------------------------------------------------------------------------------
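As a usage note, a minimal toy sketch of what the `accuracy` helper in `utils.py` reports (assuming it is run from the repository root with the pinned requirements installed):

```python
# Toy check of the top-k accuracy helper defined in utils.py.
import torch
from utils import accuracy

logits = torch.tensor([[2.0, 1.0, 0.1],   # sample 0: top-1 prediction is class 0 (correct)
                       [0.2, 0.3, 2.5]])  # sample 1: true class 1 is only ranked 2nd
targets = torch.tensor([0, 1])

top1, top2 = accuracy(logits, targets, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0 (percentages over the batch)
```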