├── LICENSE ├── README.md ├── assets ├── banner.PNG └── data_folder.PNG ├── learning ├── evaluator.py ├── lr_scheduler.py └── trainer.py ├── main.py ├── network └── resnet.py ├── option.py ├── requirements.txt └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Hoseong Lee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
[banner image: assets/banner.PNG] 3 | 4 |
5 | 6 | # automatic-mixed-precision-tutorials-pytorch 7 | Automatic Mixed Precision tutorials using PyTorch. Based on [PyTorch 1.6 Official Features (Automatic Mixed Precision)](https://pytorch.org/docs/stable/notes/amp_examples.html), this repository implements an image-classification codebase on a custom dataset. 8 | 9 | - author: hoya012 10 | - last update: 2020.09.03 11 | - [supplementary materials 1 (blog post written in Korean)](https://hoya012.github.io/blog/Mixed-Precision-Training/) 12 | - [supplementary materials 2 (blog post written in Korean)](https://hoya012.github.io/blog/Image-Classification-with-Mixed-Precision-Training-PyTorch-Tutorial/) 13 | 14 | ## 0. Experimental Setup (I used 1 GTX 1080 Ti GPU and 1 RTX 2080 Ti GPU!) 15 | ### 0-1. Prepare Library 16 | - PyTorch >= 1.6.0 is required (`torch.cuda.amp` was introduced in 1.6). 17 | 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### 0-2. Download dataset (Kaggle Intel Image Classification) 23 | 24 | - [Intel Image Classification](https://www.kaggle.com/puneet6060/intel-image-classification/) 25 | 26 | This dataset contains around 25k images of size 150x150, distributed across 6 categories: 27 | {'buildings' -> 0, 28 | 'forest' -> 1, 29 | 'glacier' -> 2, 30 | 'mountain' -> 3, 31 | 'sea' -> 4, 32 | 'street' -> 5 } 33 | 34 | - Make a `data` folder and move the dataset into it, as shown below. 35 | 36 |
[data folder layout image: assets/data_folder.PNG] 37 | 38 |
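As a quick sanity check (my sketch, not part of the repo), the loaders in `utils.py` expect the unzipped Kaggle archive under `./data` like this:

```python
# Hypothetical sanity check: verify the dataset layout expected by utils.make_dataloader.
# Expected folders after unzipping into ./data:
#   data/seg_train/seg_train/{buildings,forest,glacier,mountain,sea,street}/
#   data/seg_test/seg_test/{buildings,forest,glacier,mountain,sea,street}/
import torchvision

ds = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train")
print(len(ds))            # roughly 14k training images
print(ds.class_to_idx)    # {'buildings': 0, 'forest': 1, ..., 'street': 5}
```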
39 | 40 | ### 1. Baseline Training 41 | - ImageNet-pretrained ResNet-18 from torchvision.models 42 | - Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001 43 | - Training Augmentation: Resize((256, 256)), RandomHorizontalFlip() 44 | - Adam + cosine learning rate scheduling with warmup 45 | - Tested on one NVIDIA Pascal GPU (GTX 1080 Ti, without Tensor Cores) and one NVIDIA Turing GPU (RTX 2080 Ti, with Tensor Cores) 46 | 47 | ```bash 48 | python main.py --checkpoint_name baseline 49 | ``` 50 | 51 | ### 2. Automatic Mixed Precision Training 52 | 53 | In PyTorch 1.6, automatic mixed precision training is very easy to use! Wrap the forward pass in `torch.cuda.amp.autocast()` and drive the backward pass and optimizer step through a `torch.cuda.amp.GradScaler`. 54 | 55 | #### 2.1 Before 56 | ```python 57 | for batch_idx, (inputs, labels) in enumerate(data_loader): 58 | self.optimizer.zero_grad() 59 | 60 | outputs = self.model(inputs) 61 | loss = self.criterion(outputs, labels) 62 | 63 | loss.backward() 64 | self.optimizer.step() 65 | ``` 66 | #### 2.2 After (just add 5 lines) 67 | 68 | ```python 69 | """ define loss scaler for automatic mixed precision """ 70 | scaler = torch.cuda.amp.GradScaler() 71 | 72 | for batch_idx, (inputs, labels) in enumerate(data_loader): 73 | self.optimizer.zero_grad() 74 | 75 | with torch.cuda.amp.autocast(): 76 | outputs = self.model(inputs) 77 | loss = self.criterion(outputs, labels) 78 | 79 | # Scales the loss, and calls backward() 80 | # to create scaled gradients 81 | scaler.scale(loss).backward() 82 | 83 | # Unscales gradients and calls 84 | # or skips optimizer.step() 85 | scaler.step(self.optimizer) 86 | 87 | # Updates the scale for the next iteration 88 | scaler.update() 89 | ``` 90 | 91 | #### 2.3 Run Script (Command Line) 92 | ```bash 93 | python main.py --checkpoint_name baseline_amp --amp 94 | ``` 95 | 96 | ### 3. Performance Table 97 | - B: baseline training (FP32) 98 | - AMP: automatic mixed precision training 99 | 100 | | Algorithm | Test Accuracy (%) | GPU Memory | Total Training Time | 101 | |:------------:|:-----------------:|:--------------:|:-------------------:| 102 | | B - 1080 Ti | 94.13 | 10737MB | 64.9m | 103 | | B - 2080 Ti | 94.17 | 10855MB | 54.3m | 104 | | AMP - 1080 Ti| 94.07 | 6615MB | 64.7m | 105 | | AMP - 2080 Ti| 94.23 | 7799MB | 37.3m | 106 | 107 | 
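The memory and time numbers above were collected during the training runs; as a rough illustration (my sketch, not the repo's actual instrumentation), peak GPU memory and wall-clock time can be measured like this:

```python
# Hypothetical measurement sketch: record wall-clock time and peak GPU memory
# around a training run, as reported in the table above.
import time
import torch

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start = time.time()

# ... run the training loop here ...

torch.cuda.synchronize()
print('total time: {:.1f}m'.format((time.time() - start) / 60))
print('peak memory: {:.0f}MB'.format(torch.cuda.max_memory_allocated() / 1024 ** 2))
```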
### 4. Code Reference 108 | - Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch 109 | - Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr 110 | - PyTorch Automatic Mixed Precision: https://pytorch.org/docs/stable/notes/amp_examples.html 111 | -------------------------------------------------------------------------------- /assets/banner.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/automatic-mixed-precision-tutorials-pytorch/53f785ee07f5f09078014f7817b87487ea497eb2/assets/banner.PNG -------------------------------------------------------------------------------- /assets/data_folder.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hoya012/automatic-mixed-precision-tutorials-pytorch/53f785ee07f5f09078014f7817b87487ea497eb2/assets/data_folder.PNG -------------------------------------------------------------------------------- /learning/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from utils import AverageMeter, accuracy 5 | 6 | class Evaluator(): 7 | def __init__(self, model, criterion): 8 | self.model = model 9 | self.criterion = criterion 10 | self.save_path = os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name, 'result_dict.json') 11 | if not os.path.exists(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)): 12 | os.makedirs(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)) 13 | 14 | def worst_result(self): 15 | ret = { 16 | 'loss': float('inf'), 17 | 'accuracy': 0.0 18 | } 19 | return ret 20 | 21 | def result_to_str(self, result): 22 | ret = [ 23 | 'epoch: {epoch:0>3}', 24 | 'loss: {loss: >4.2e}' 25 | ] 26 | for metric in getattr(self, 'evaluation_metrics', []):  # optional extra metrics (not set anywhere in this repo) 27 | ret.append('{}: {}'.format(metric.name, metric.fmtstr)) 28 | return ', '.join(ret).format(**result) 29 | 30 | def save(self, result): 31 | with open(self.save_path, 'w') as f: 32 | f.write(json.dumps(result, sort_keys=True, indent=4, ensure_ascii=False)) 33 | 34 | def load(self): 35 | result = self.worst_result() 36 | if os.path.exists(self.save_path): 37 | with open(self.save_path, 'r') as f: 38 | try: 39 | result = json.loads(f.read()) 40 | except json.JSONDecodeError: 41 | pass  # fall back to the worst result on a corrupt file 42 | return result 43 | 44 | def evaluate(self, data_loader, epoch, args, result_dict): 45 | losses = AverageMeter() 46 | top1 = AverageMeter() 47 | 48 | self.model.eval() 49 | 50 | with torch.no_grad(): 51 | for batch_idx, (inputs, labels) in enumerate(data_loader): 52 | inputs, labels = inputs.cuda(), labels.cuda() 53 | if args.amp: 54 | with torch.cuda.amp.autocast(): 55 | outputs = self.model(inputs) 56 | loss = self.criterion(outputs, labels) 57 | else: 58 | outputs = self.model(inputs) 59 | loss = self.criterion(outputs, labels) 60 | 61 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 62 | losses.update(loss.item(), inputs.size(0)) 63 | top1.update(prec1.item(), inputs.size(0)) 64 | 65 | print('----Validation Results Summary----') 66 | print('Epoch: [{}] Top-1 accuracy: {:.2f}%'.format(epoch, top1.avg)) 67 | 68 | result_dict['val_loss'].append(losses.avg) 69 | result_dict['val_acc'].append(top1.avg) 70 | 71 | return result_dict 72 | 
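# Note: evaluate() wraps the forward pass in autocast() when --amp is set, but no
# GradScaler is used here. Under torch.no_grad() there is no backward pass, so
# autocast alone is enough for mixed-precision inference, e.g. (my sketch):
#     with torch.no_grad(), torch.cuda.amp.autocast():
#         outputs = model(inputs)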
73 | def test(self, data_loader, args, result_dict): 74 | top1 = AverageMeter() 75 | 76 | self.model.eval() 77 | with torch.no_grad(): 78 | for batch_idx, (inputs, labels) in enumerate(data_loader): 79 | inputs, labels = inputs.cuda(), labels.cuda() 80 | outputs = self.model(inputs) 81 | 82 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 83 | top1.update(prec1.item(), inputs.size(0)) 84 | 85 | print('----Test Set Results Summary----') 86 | print('Top-1 accuracy: {:.2f}%'.format(top1.avg)) 87 | 88 | result_dict['test_acc'].append(top1.avg) 89 | 90 | return result_dict 91 | 92 | -------------------------------------------------------------------------------- /learning/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | """ 6 | Gradually warm up (increase) the learning rate in the optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | 9 | Code Reference: https://github.com/ildoonet/pytorch-gradual-warmup-lr 10 | """ 11 | 12 | class GradualWarmupScheduler(_LRScheduler): 13 | """ Gradually warm up (increase) the learning rate in the optimizer. 14 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 15 | Args: 16 | optimizer (Optimizer): Wrapped optimizer. 17 | multiplier: target learning rate = base_lr * multiplier if multiplier > 1.0; if multiplier = 1.0, the lr starts from 0 and ends up at base_lr. 18 | total_epoch: the target learning rate is reached gradually over total_epoch epochs 19 | after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau) 20 | """ 21 | 22 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 23 | self.multiplier = multiplier 24 | if self.multiplier < 1.: 25 | raise ValueError('multiplier should be greater than or equal to 1.') 26 | self.total_epoch = total_epoch 27 | self.after_scheduler = after_scheduler 28 | self.finished = False 29 | super(GradualWarmupScheduler, self).__init__(optimizer) 30 | 31 | def get_lr(self): 32 | if self.last_epoch > self.total_epoch: 33 | if self.after_scheduler: 34 | if not self.finished: 35 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 36 | self.finished = True 37 | return self.after_scheduler.get_last_lr() 38 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 39 | 40 | if self.multiplier == 1.0: 41 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs] 42 | else: 43 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 
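# Worked example of the warmup math above, using this repo's defaults
# (learning_rate=1e-4, multiplier=1, total_epoch=epochs//10=12, as set in utils.py):
# get_lr() returns base_lr * last_epoch / total_epoch, i.e. 0.0 at epoch 0,
# 5e-5 at epoch 6, and 1e-4 at epoch 12, after which the wrapped
# CosineAnnealingLR takes over.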
44 | 45 | def step(self, epoch=None, metrics=None): 46 | if self.finished and self.after_scheduler: 47 | if epoch is None: 48 | self.after_scheduler.step(None) 49 | else: 50 | self.after_scheduler.step(epoch - self.total_epoch) 51 | self._last_lr = self.after_scheduler.get_last_lr() 52 | else: 53 | return super(GradualWarmupScheduler, self).step(epoch) 54 | -------------------------------------------------------------------------------- /learning/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from utils import AverageMeter, accuracy 4 | 5 | class Trainer: 6 | def __init__(self, model, criterion, optimizer, scheduler, scaler): 7 | self.model = model 8 | self.criterion = criterion 9 | self.optimizer = optimizer 10 | self.scheduler = scheduler 11 | self.scaler = scaler 12 | 13 | def train(self, data_loader, epoch, args, result_dict): 14 | total_loss = 0 15 | count = 0 16 | 17 | losses = AverageMeter() 18 | top1 = AverageMeter() 19 | 20 | self.model.train() 21 | 22 | for batch_idx, (inputs, labels) in enumerate(data_loader): 23 | inputs, labels = inputs.cuda(), labels.cuda() 24 | 25 | if args.amp: 26 | with torch.cuda.amp.autocast(): 27 | outputs = self.model(inputs) 28 | loss = self.criterion(outputs, labels) 29 | else: 30 | outputs = self.model(inputs) 31 | loss = self.criterion(outputs, labels) 32 | 33 | if len(labels.size()) > 1: 34 | labels = torch.argmax(labels, axis=1)  # collapse one-hot labels to class indices for the accuracy metric 35 | 36 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3)) 37 | losses.update(loss.item(), inputs.size(0)) 38 | top1.update(prec1.item(), inputs.size(0)) 39 | 40 | self.optimizer.zero_grad() 41 | 42 | if args.amp: 43 | self.scaler.scale(loss).backward() 44 | self.scaler.step(self.optimizer) 45 | self.scaler.update() 46 | else: 47 | loss.backward() 48 | self.optimizer.step() 49 | total_loss += loss.item() * inputs.size(0)  # accumulate the batch sum so total_loss / count is a per-sample average 50 | count += labels.size(0) 51 | 52 | if batch_idx % args.log_interval == 0: 53 | _s = str(len(str(len(data_loader.sampler)))) 54 | ret = [ 55 | ('epoch: {:0>3} [{: >' + _s + '}/{} ({: >3.0f}%)]').format(epoch, count, len(data_loader.sampler), 100 * count / len(data_loader.sampler)), 56 | 'train_loss: {: >4.2e}'.format(total_loss / count), 57 | 'train_accuracy : {:.2f}%'.format(top1.avg) 58 | ] 59 | print(', '.join(ret)) 60 | 61 | self.scheduler.step() 62 | result_dict['train_loss'].append(losses.avg) 63 | result_dict['train_acc'].append(top1.avg) 64 | 65 | return result_dict -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os, sys, time 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torchvision 6 | 7 | PATH = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.insert(0, PATH + '/../..') 9 | 10 | from option import get_args 11 | from learning.trainer import Trainer 12 | from learning.evaluator import Evaluator 13 | from utils import get_model, make_optimizer, make_scheduler, make_dataloader, plot_learning_curves 14 | 15 | def main(): 16 | args = get_args() 17 | torch.manual_seed(args.seed) 18 | 19 | shape = (224, 224, 3) 20 | 21 | """ define dataloader """ 22 | train_loader, valid_loader, test_loader = make_dataloader(args) 23 | 24 | """ define model architecture """ 25 | model = get_model(args, shape, args.num_classes) 26 | 27 | if torch.cuda.device_count() >= 1: 28 | print('Model pushed to {} GPU(s), type 
{}.'.format(torch.cuda.device_count(), torch.cuda.get_device_name(0))) 29 | model = model.cuda() 30 | else: 31 | raise ValueError('CPU training is not supported') 32 | 33 | """ define loss criterion """ 34 | criterion = nn.CrossEntropyLoss().cuda() 35 | 36 | """ define optimizer """ 37 | optimizer = make_optimizer(args, model) 38 | 39 | """ define learning rate scheduler """ 40 | scheduler = make_scheduler(args, optimizer) 41 | 42 | """ define loss scaler for automatic mixed precision """ 43 | scaler = torch.cuda.amp.GradScaler() 44 | 45 | """ define trainer, evaluator, result_dictionary """ 46 | result_dict = {'args':vars(args), 'epoch':[], 'train_loss' : [], 'train_acc' : [], 'val_loss' : [], 'val_acc' : [], 'test_acc':[]} 47 | trainer = Trainer(model, criterion, optimizer, scheduler, scaler) 48 | evaluator = Evaluator(model, criterion) 49 | 50 | train_time_list = [] 51 | valid_time_list = [] 52 | 53 | if args.evaluate: 54 | """ load model checkpoint """ 55 | model.load() 56 | result_dict = evaluator.test(test_loader, args, result_dict) 57 | else: 58 | evaluator.save(result_dict) 59 | 60 | best_val_acc = 0.0 61 | """ define training loop """ 62 | for epoch in range(args.epochs): 63 | result_dict['epoch'] = epoch 64 | 65 | torch.cuda.synchronize() 66 | tic1 = time.time() 67 | 68 | result_dict = trainer.train(train_loader, epoch, args, result_dict) 69 | 70 | torch.cuda.synchronize() 71 | tic2 = time.time() 72 | train_time_list.append(tic2 - tic1) 73 | 74 | torch.cuda.synchronize() 75 | tic3 = time.time() 76 | 77 | result_dict = evaluator.evaluate(valid_loader, epoch, args, result_dict) 78 | 79 | torch.cuda.synchronize() 80 | tic4 = time.time() 81 | valid_time_list.append(tic4 - tic3) 82 | 83 | if result_dict['val_acc'][-1] > best_val_acc: 84 | print("{} epoch, best epoch was updated! 
{}%".format(epoch, result_dict['val_acc'][-1])) 85 | best_val_acc = result_dict['val_acc'][-1] 86 | model.save(checkpoint_name='best_model') 87 | 88 | evaluator.save(result_dict) 89 | plot_learning_curves(result_dict, epoch, args) 90 | 91 | result_dict = evaluator.test(test_loader, args, result_dict) 92 | evaluator.save(result_dict) 93 | 94 | """ calculate test accuracy using best model """ 95 | model.load(checkpoint_name='best_model') 96 | result_dict = evaluator.test(test_loader, args, result_dict) 97 | evaluator.save(result_dict) 98 | 99 | print(result_dict) 100 | 101 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'train_time_amp.csv'), train_time_list, delimiter=',', fmt='%s') 102 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'valid_time_amp.csv'), valid_time_list, delimiter=',', fmt='%s') 103 | 104 | if __name__ == '__main__': 105 | main() -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import resnet 5 | 6 | __all__ = [ 7 | 'ResNet18', 'ResNet50' 8 | ] 9 | 10 | class Flatten(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, x): 15 | return x.view(x.size(0), -1) 16 | 17 | def conv3x3(in_channel, out_channel, stride=1): 18 | return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False) 19 | 20 | def conv1x1(in_channel, out_channel, stride=1): 21 | return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False) 22 | 23 | class ResNet(nn.Module): 24 | def __init__(self, block, layer_config, num_classes=2, norm='batch', zero_init_residual=False): 25 | super(ResNet, self).__init__() 26 | norm = nn.BatchNorm2d 27 | 28 | self.in_channel = 64 29 | 30 | self.conv = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False) 31 | self.norm = norm(self.in_channel) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 34 | self.layer1 = self.make_layer(block, 64*block.expansion, layer_config[0], stride=1, norm=norm) 35 | self.layer2 = self.make_layer(block, 128*block.expansion, layer_config[1], stride=2, norm=norm) 36 | self.layer3 = self.make_layer(block, 256*block.expansion, layer_config[2], stride=2, norm=norm) 37 | self.layer4 = self.make_layer(block, 512*block.expansion, layer_config[3], stride=2, norm=norm) 38 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 39 | self.flatten = Flatten() 40 | self.dense = nn.Linear(512*block.expansion, num_classes) 41 | 42 | for m in self.modules(): 43 | if isinstance(m, nn.Conv2d): 44 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 45 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 46 | nn.init.constant_(m.weight, 1) 47 | nn.init.constant_(m.bias, 0) 48 | 49 | # Zero-initialize the last BN in each residual branch, 50 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
51 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 52 | if zero_init_residual: 53 | for m in self.modules(): 54 | if isinstance(m, BasicBlock): 55 | nn.init.constant_(m.norm2.weight, 0) 56 | elif isinstance(m, Bottleneck): 57 | nn.init.constant_(m.norm3.weight, 0) 58 | 59 | def make_layer(self, block, out_channel, num_blocks, stride=1, norm=None): 60 | norm = nn.BatchNorm2d 61 | 62 | downsample = None 63 | if stride != 1 or self.in_channel != out_channel: 64 | downsample = nn.Sequential( 65 | conv1x1(self.in_channel, out_channel, stride), 66 | norm(out_channel), 67 | ) 68 | layers = [] 69 | layers.append(block(self.in_channel, out_channel, stride, downsample, norm)) 70 | self.in_channel = out_channel 71 | for _ in range(1, num_blocks): 72 | layers.append(block(self.in_channel, out_channel, norm=norm)) 73 | return nn.Sequential(*layers) 74 | 75 | def forward(self, x): 76 | out = self.conv(x) 77 | out = self.norm(out) 78 | out = self.relu(out) 79 | out = self.maxpool(out) 80 | out = self.layer1(out) 81 | out = self.layer2(out) 82 | out = self.layer3(out) 83 | out = self.layer4(out) 84 | out = self.avgpool(out) 85 | out = self.flatten(out) 86 | out = self.dense(out) 87 | return out 88 | 89 | 90 | class BasicBlock(nn.Module): 91 | expansion = 1 92 | 93 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None): 94 | super(BasicBlock, self).__init__() 95 | norm = nn.BatchNorm2d 96 | 97 | self.conv1 = conv3x3(in_channel, out_channel, stride) 98 | self.norm1 = norm(out_channel) 99 | self.conv2 = conv3x3(out_channel, out_channel) 100 | self.norm2 = norm(out_channel) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | 104 | def forward(self, x): 105 | identity = x 106 | out = self.conv1(x) 107 | out = self.norm1(out) 108 | out = self.relu(out) 109 | out = self.conv2(out) 110 | out = self.norm2(out) 111 | if self.downsample is not None: 112 | identity = self.downsample(identity) 113 | out += identity 114 | out = self.relu(out) 115 | return out 116 | 117 | class Bottleneck(nn.Module): 118 | expansion = 4 119 | 120 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None): 121 | super(Bottleneck, self).__init__() 122 | norm = nn.BatchNorm2d 123 | 124 | mid_channel= out_channel // self.expansion 125 | self.conv1 = conv1x1(in_channel, mid_channel) 126 | self.norm1 = norm(mid_channel) 127 | self.conv2 = conv3x3(mid_channel, mid_channel, stride) 128 | self.norm2 = norm(mid_channel) 129 | self.conv3 = conv1x1(mid_channel, out_channel) 130 | self.norm3 = norm(out_channel) 131 | self.relu = nn.ReLU(inplace=True) 132 | self.downsample = downsample 133 | 134 | def forward(self, x): 135 | identity = x 136 | out = self.conv1(x) 137 | out = self.norm1(out) 138 | out = self.relu(out) 139 | out = self.conv2(out) 140 | out = self.norm2(out) 141 | out = self.relu(out) 142 | out = self.conv3(out) 143 | out = self.norm3(out) 144 | if self.downsample is not None: 145 | identity = self.downsample(x) 146 | out += identity 147 | out = self.relu(out) 148 | return out 149 | 150 | 151 | class ResNet18(nn.Module): 152 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet18', 153 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False): 154 | super(ResNet18, self).__init__() 155 | 156 | if len(shape) != 3: 157 | raise ValueError('Invalid shape: {}'.format(shape)) 158 | self.shape = shape 159 | self.num_classes = num_classes 160 | 
self.checkpoint_dir = checkpoint_dir 161 | self.checkpoint_name = checkpoint_name 162 | self.H, self.W, self.C = shape 163 | 164 | if not os.path.exists(checkpoint_dir): 165 | os.makedirs(checkpoint_dir) 166 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt') 167 | 168 | model = ResNet(BasicBlock, [2,2,2, 2], num_classes, norm, zero_init_residual) 169 | 170 | if pretrained: 171 | print("Pretrained Weight is loaded!!") 172 | if pretrained_path is None: 173 | print("Loading from torchvision models") 174 | model = resnet.resnet18(pretrained=True) 175 | if zero_init_residual: 176 | for m in model.modules(): 177 | if isinstance(m, resnet.Bottleneck): 178 | nn.init.constant_(m.bn3.weight, 0) 179 | elif isinstance(m, BasicBlock): 180 | nn.init.constant_(m.bn2.weight, 0) 181 | else: 182 | checkpoint = torch.load(pretrained_path) 183 | model.load_state_dict(checkpoint) 184 | 185 | self.features = nn.Sequential(*list(model.children())[:-2]) 186 | self.num_features = 512 * BasicBlock.expansion 187 | self.classifier = nn.Sequential( 188 | nn.AdaptiveAvgPool2d((1, 1)), 189 | Flatten(), 190 | nn.Linear(self.num_features, num_classes) 191 | ) 192 | 193 | def save(self, checkpoint_name=''): 194 | if checkpoint_name == '': 195 | torch.save(self.state_dict(), self.checkpoint_path) 196 | else: 197 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 198 | torch.save(self.state_dict(), checkpoint_path) 199 | 200 | def load(self, checkpoint_name=''): 201 | if checkpoint_name == '': 202 | self.load_state_dict(torch.load(self.checkpoint_path)) 203 | else: 204 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 205 | self.load_state_dict(torch.load(checkpoint_path)) 206 | 207 | def forward(self, x): 208 | out = x 209 | out = self.features(out) 210 | out = self.classifier(out) 211 | return out 212 | 213 | class ResNet50(nn.Module): 214 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet50', 215 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False): 216 | super(ResNet50, self).__init__() 217 | 218 | if len(shape) != 3: 219 | raise ValueError('Invalid shape: {}'.format(shape)) 220 | self.shape = shape 221 | self.num_classes = num_classes 222 | self.checkpoint_dir = checkpoint_dir 223 | self.checkpoint_name = checkpoint_name 224 | self.H, self.W, self.C = shape 225 | 226 | if not os.path.exists(checkpoint_dir): 227 | os.makedirs(checkpoint_dir) 228 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt') 229 | 230 | model = ResNet(Bottleneck, [3,4,6,3], num_classes, norm, zero_init_residual) 231 | 232 | if pretrained: 233 | print("Pretrained Weight is loaded!!") 234 | if pretrained_path is None: 235 | print("Loading from torchvision models") 236 | model = resnet.resnet50(pretrained=True) 237 | if zero_init_residual: 238 | for m in model.modules(): 239 | if isinstance(m, resnet.Bottleneck): 240 | nn.init.constant_(m.bn3.weight, 0) 241 | else: 242 | checkpoint = torch.load(pretrained_path) 243 | model.load_state_dict(checkpoint) 244 | 245 | self.features = nn.Sequential(*list(model.children())[:-2]) 246 | self.num_features = 512 * Bottleneck.expansion 247 | self.classifier = nn.Sequential( 248 | nn.AdaptiveAvgPool2d((1, 1)), 249 | Flatten(), 250 | nn.Linear(self.num_features, num_classes) 251 | ) 252 | 253 | def save(self, checkpoint_name=''): 254 | if checkpoint_name == '': 255 | 
torch.save(self.state_dict(), self.checkpoint_path) 256 | else: 257 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 258 | torch.save(self.state_dict(), checkpoint_path) 259 | 260 | def load(self, checkpoint_name=''): 261 | if checkpoint_name == '': 262 | self.load_state_dict(torch.load(self.checkpoint_path)) 263 | else: 264 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt') 265 | self.load_state_dict(torch.load(checkpoint_path)) 266 | 267 | def forward(self, x): 268 | out = x 269 | out = self.features(out) 270 | out = self.classifier(out) 271 | return out -------------------------------------------------------------------------------- /option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | 6 | # model architecture & checkpoint 7 | parser.add_argument('--model', default='ResNet18', choices=('ResNet18', 'ResNet50'), 8 | help='model architecture to use (ResNet18 | ResNet50)') 9 | parser.add_argument('--norm', default='batchnorm') 10 | parser.add_argument('--num_classes', type=int, default=6) 11 | parser.add_argument('--pretrained', type=int, default=1) 12 | parser.add_argument('--pretrained_path', type=str, default=None) 13 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint') 14 | parser.add_argument('--checkpoint_name', type=str, default='') 15 | 16 | # data loading 17 | parser.add_argument('--num_workers', type=int, default=16) 18 | parser.add_argument('--seed', type=int, default=42, help='random seed') 19 | 20 | # training hyper parameters 21 | parser.add_argument('--batch_size', type=int, default=256) 22 | parser.add_argument('--epochs', type=int, default=120) 23 | parser.add_argument('--log_interval', type=int, default=20) 24 | parser.add_argument('--evaluate', action='store_true', default=False) 25 | parser.add_argument('--amp', action='store_true', default=False) 26 | 27 | # optimizer & learning rate scheduler 28 | parser.add_argument('--learning_rate', type=float, default=0.0001) 29 | parser.add_argument('--weight_decay', type=float, default=0.0001) 30 | parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM'), 31 | help='optimizer to use (SGD | ADAM)') 32 | parser.add_argument('--decay_type', default='cosine_warmup', choices=('step', 'step_warmup', 'cosine_warmup'), 33 | help='learning rate decay schedule to use (step | step_warmup | cosine_warmup)') 34 | 35 | args = parser.parse_args() 36 | return args 37 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cycler==0.10.0 2 | future==0.18.2 3 | joblib==0.15.1 4 | kiwisolver==1.2.0 5 | matplotlib==3.2.2 6 | numpy==1.19.0 7 | Pillow==7.1.2 8 | pyparsing==2.4.7 9 | python-dateutil==2.8.1 10 | scikit-learn==0.23.1 11 | scipy==1.5.0 12 | six==1.15.0 13 | threadpoolctl==2.1.0 14 | torch==1.6.0 15 | torchvision==0.7.0 16 |
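Given the pinned torch==1.6.0 above, a quick pre-flight check (my sketch, not part of the repo) can confirm the AMP API and a Tensor-Core GPU are available:

```python
# Hypothetical environment check before training.
import torch

print(torch.__version__)                    # needs >= 1.6.0 for torch.cuda.amp
print(torch.cuda.is_available())            # main.py raises on CPU-only machines
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_capability(0))  # (7, 0) or higher means Tensor Cores (Volta/Turing+)
```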
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from sklearn.model_selection import train_test_split 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as lrs 10 | from torch.utils.data.sampler import SubsetRandomSampler 11 | import torchvision 12 | from torchvision import transforms as T 13 | 14 | from network.resnet import * 15 | from learning.lr_scheduler import GradualWarmupScheduler 16 | 17 | def get_model(args, shape, num_classes): 18 | model = eval(args.model)(  # resolves 'ResNet18' / 'ResNet50' to the classes imported from network.resnet 19 | shape, 20 | num_classes, 21 | checkpoint_dir=args.checkpoint_dir, 22 | checkpoint_name=args.checkpoint_name, 23 | pretrained=args.pretrained, 24 | pretrained_path=args.pretrained_path, 25 | norm=args.norm, 26 | ) 27 | return model 28 | 29 | 30 | def make_optimizer(args, model): 31 | trainable = filter(lambda x: x.requires_grad, model.parameters()) 32 | 33 | if args.optimizer == 'SGD': 34 | optimizer_function = optim.SGD 35 | kwargs = {'momentum': 0.9} 36 | elif args.optimizer == 'ADAM': 37 | optimizer_function = optim.Adam 38 | kwargs = { 39 | 'betas': (0.9, 0.999), 40 | 'eps': 1e-08 41 | } 42 | else: 43 | raise NameError('unsupported optimizer: {}'.format(args.optimizer)) 44 | 45 | kwargs['lr'] = args.learning_rate 46 | kwargs['weight_decay'] = args.weight_decay 47 | 48 | return optimizer_function(trainable, **kwargs) 49 | 50 | 51 | def make_scheduler(args, optimizer): 52 | if args.decay_type == 'step': 53 | scheduler = lrs.MultiStepLR( 54 | optimizer, 55 | milestones=[30, 60, 90], 56 | gamma=0.1 57 | ) 58 | elif args.decay_type == 'step_warmup': 59 | scheduler = lrs.MultiStepLR( 60 | optimizer, 61 | milestones=[30, 60, 90], 62 | gamma=0.1 63 | ) 64 | scheduler = GradualWarmupScheduler( 65 | optimizer, 66 | multiplier=1, 67 | total_epoch=5, 68 | after_scheduler=scheduler 69 | ) 70 | elif args.decay_type == 'cosine_warmup': 71 | cosine_scheduler = lrs.CosineAnnealingLR( 72 | optimizer, 73 | T_max=args.epochs 74 | ) 75 | scheduler = GradualWarmupScheduler( 76 | optimizer, 77 | multiplier=1, 78 | total_epoch=args.epochs//10, 79 | after_scheduler=cosine_scheduler 80 | ) 81 | else: 82 | raise Exception('unknown lr scheduler: {}'.format(args.decay_type)) 83 | 84 | return scheduler 85 | 86 | def make_dataloader(args): 87 | 88 | train_trans = T.Compose([ 89 | T.Resize((256, 256)), 90 | T.RandomHorizontalFlip(), 91 | T.ToTensor(), 92 | ]) 93 | 94 | valid_trans = T.Compose([ 95 | T.Resize((256, 256)), 96 | T.ToTensor(), 97 | ]) 98 | 99 | test_trans = T.Compose([ 100 | T.Resize((256, 256)), 101 | T.ToTensor(), 102 | ]) 103 | 104 | trainset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=train_trans) 105 | validset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=valid_trans) 106 | testset = torchvision.datasets.ImageFolder(root="data/seg_test/seg_test", transform=test_trans) 107 | 108 | np.random.seed(args.seed) 109 | targets = trainset.targets 110 | train_idx, valid_idx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets) 111 | train_sampler = SubsetRandomSampler(train_idx) 112 | valid_sampler = SubsetRandomSampler(valid_idx) 113 | 114 | train_loader = torch.utils.data.DataLoader( 115 | trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers 116 | ) 117 | 118 | valid_loader = torch.utils.data.DataLoader( 119 | validset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=args.num_workers 120 | ) 121 | 122 | test_loader = torch.utils.data.DataLoader( 123 | testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 124 | ) 125 | 126 | return train_loader, valid_loader, test_loader 127 | 128 | 
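# Note: make_dataloader() builds two ImageFolder instances over the same
# seg_train folder so that train and validation can use different transforms;
# the stratified 80/20 train_test_split plus SubsetRandomSampler keeps the two
# index sets disjoint, while the held-out seg_test folder is used for testing.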
129 | def plot_learning_curves(metrics, cur_epoch, args): 130 | x = np.arange(cur_epoch+1) 131 | fig, ax1 = plt.subplots() 132 | ax1.set_xlabel('epochs') 133 | ax1.set_ylabel('loss') 134 | ln1 = ax1.plot(x, metrics['train_loss'], color='tab:red') 135 | ln2 = ax1.plot(x, metrics['val_loss'], color='tab:red', linestyle='dashed') 136 | ax1.grid() 137 | ax2 = ax1.twinx() 138 | ax2.set_ylabel('accuracy') 139 | ln3 = ax2.plot(x, metrics['train_acc'], color='tab:blue') 140 | ln4 = ax2.plot(x, metrics['val_acc'], color='tab:blue', linestyle='dashed') 141 | lns = ln1+ln2+ln3+ln4 142 | plt.legend(lns, ['Train loss', 'Validation loss', 'Train accuracy', 'Validation accuracy']) 143 | plt.tight_layout() 144 | plt.savefig('{}/{}/learning_curve.png'.format(args.checkpoint_dir, args.checkpoint_name), bbox_inches='tight') 145 | plt.close('all') 146 | 147 | class AverageMeter(object): 148 | """Computes and stores the average and current value""" 149 | def __init__(self): 150 | self.reset() 151 | 152 | def reset(self): 153 | self.val = 0 154 | self.avg = 0 155 | self.sum = 0 156 | self.count = 0 157 | self.max = 0 158 | self.min = 1e5 159 | 160 | def update(self, val, n=1): 161 | self.val = val 162 | self.sum += val * n 163 | self.count += n 164 | self.avg = self.sum / self.count 165 | if val > self.max: 166 | self.max = val 167 | if val < self.min: 168 | self.min = val 169 | 170 | def accuracy(output, target, topk=(1,)): 171 | """Computes the precision@k for the specified values of k""" 172 | maxk = max(topk) 173 | batch_size = target.size(0) 174 | 175 | _, pred = output.topk(maxk, 1, True, True) 176 | pred = pred.t() 177 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 178 | 179 | res = [] 180 | for k in topk: 181 | correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape: the slice may be non-contiguous after t() 182 | res.append(correct_k.mul_(100.0 / batch_size)) 183 | return res --------------------------------------------------------------------------------
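As a closing sanity check (my example, not part of the repo), `utils.accuracy` returns top-k precision as percentages:

```python
# Hypothetical quick test of utils.accuracy on a toy batch of two samples.
import torch
from utils import accuracy

logits = torch.tensor([[0.1, 0.7, 0.2],     # sample 0: top-1 prediction is class 1 (correct)
                       [0.8, 0.15, 0.05]])  # sample 1: top-1 is class 0 (wrong), class 1 only in top-2
labels = torch.tensor([1, 1])
top1, top2 = accuracy(logits, labels, topk=(1, 2))
print(top1.item(), top2.item())  # 50.0 100.0
```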