├── LICENSE
├── README.md
├── assets
│   ├── banner.PNG
│   └── data_folder.PNG
├── learning
│   ├── evaluator.py
│   ├── lr_scheduler.py
│   └── trainer.py
├── main.py
├── network
│   └── resnet.py
├── option.py
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hoseong Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ![banner](assets/banner.PNG)
2 |
3 |
4 |
5 |
6 | # automatic-mixed-precision-tutorials-pytorch
7 | Automatic Mixed Precision tutorials using PyTorch. Based on [PyTorch 1.6 Official Features (Automatic Mixed Precision)](https://pytorch.org/docs/stable/notes/amp_examples.html), this repository implements a classification codebase on a custom dataset.
8 |
9 | - author: hoya012
10 | - last update: 2020.09.03
11 | - [supplementary materials 1 (blog post written in Korean)](https://hoya012.github.io/blog/Mixed-Precision-Training/)
12 | - [supplementary materials 2 (blog post written in Korean)](https://hoya012.github.io/blog/Image-Classification-with-Mixed-Precision-Training-PyTorch-Tutorial/)
13 |
14 | ## 0. Experimental Setup (I used 1 GTX 1080 Ti GPU and 1 RTX 2080 Ti GPU!)
15 | ### 0-1. Prepare Library
16 | - Requires PyTorch 1.6.0 or newer (for `torch.cuda.amp`).
17 |
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
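
To confirm the environment is ready, a quick sanity check like the one below (a generic sketch, not part of this repo) should report a version of at least 1.6.0 and `True` for CUDA availability:

```python
import torch

print(torch.__version__)           # torch.cuda.amp (autocast / GradScaler) requires >= 1.6.0
print(torch.cuda.is_available())   # the training code in this repo assumes a CUDA GPU
```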
21 |
22 | ### 0-2. Download dataset (Kaggle Intel Image Classification)
23 |
24 | - [Intel Image Classification](https://www.kaggle.com/puneet6060/intel-image-classification/)
25 |
26 | This dataset contains around 25k images of size 150x150, distributed across 6 categories:
27 | - 'buildings' -> 0
28 | - 'forest' -> 1
29 | - 'glacier' -> 2
30 | - 'mountain' -> 3
31 | - 'sea' -> 4
32 | - 'street' -> 5
33 |
34 | - Create a `data` folder and move the extracted dataset into it (the layout the data loaders expect is sketched below).
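
The loaders in `utils.py` read the extracted Kaggle folders directly (`data/seg_train/seg_train` for train/validation, `data/seg_test/seg_test` for test), and `torchvision.datasets.ImageFolder` derives the class-to-index mapping above from the sub-folder names. A minimal sketch:

```python
import torchvision

# ImageFolder assigns labels from the alphabetically sorted class folder names,
# which reproduces the mapping listed above.
trainset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train")
print(trainset.class_to_idx)
# {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}
```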
35 |
36 |
37 | ![data folder layout](assets/data_folder.PNG)
38 |
39 |
40 | ## 1. Baseline Training
41 | - ImageNet Pretrained ResNet-18 from torchvision.models
42 | - Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001
43 | - Training Augmentation: Resize((256, 256)), RandomHorizontalFlip()
44 | - Adam + cosine learning rate schedule with gradual warmup (see the sketch after the run command below)
45 | - Tested on one NVIDIA Pascal GPU (GTX 1080 Ti, without Tensor Cores) and one NVIDIA Turing GPU (RTX 2080 Ti, with Tensor Cores)
46 |
47 | ```bash
48 | python main.py --checkpoint_name baseline;
49 | ```
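
The warmup + cosine schedule above is what `utils.make_scheduler` builds by wrapping `CosineAnnealingLR` in the `GradualWarmupScheduler` from `learning/lr_scheduler.py`. A condensed sketch, where `model` and `epochs` stand in for the objects created in `main.py` and the values are the defaults from `option.py`:

```python
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from learning.lr_scheduler import GradualWarmupScheduler

# Adam with the default learning rate / weight decay from option.py
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Linear warmup over the first epochs // 10 epochs, then cosine annealing
cosine = lrs.CosineAnnealingLR(optimizer, T_max=epochs)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
                                   total_epoch=epochs // 10,
                                   after_scheduler=cosine)
```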
50 |
51 | ## 2. Automatic Mixed Precision Training
52 |
53 | In PyTorch 1.6, Automatic Mixed Precision Training is very easy to use! Thanks to PyTorch!
54 |
55 | ### 2.1 Before
56 | ```python
57 | for batch_idx, (inputs, labels) in enumerate(data_loader):
58 | self.optimizer.zero_grad()
59 |
60 | outputs = self.model(inputs)
61 | loss = self.criterion(outputs, labels)
62 |
63 | loss.backward()
64 | self.optimizer.step()
65 | ```
66 | ### 2.2 After (just add 5 lines)
67 |
68 | ```python
69 | """ define loss scaler for automatic mixed precision """
70 | self.scaler = torch.cuda.amp.GradScaler()
71 |
72 | for batch_idx, (inputs, labels) in enumerate(data_loader):
73 | self.optimizer.zero_grad()
74 |
75 | with torch.cuda.amp.autocast():
76 | outputs = self.model(inputs)
77 | loss = self.criterion(outputs, labels)
78 |
79 | # Scales the loss, and calls backward()
80 | # to create scaled gradients
81 | self.scaler.scale(loss).backward()
82 |
83 | # Unscales gradients and calls
84 | # or skips optimizer.step()
85 | self.scaler.step(self.optimizer)
86 |
87 | # Updates the scale for next iteration
88 | self.scaler.update()
89 | ```
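
Evaluation can also run under `autocast`; since there is no backward pass, the `GradScaler` is not involved. A minimal sketch of what `learning/evaluator.py` does:

```python
self.model.eval()
with torch.no_grad():
    for inputs, labels in data_loader:
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.cuda.amp.autocast():
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
```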
90 |
91 | ### 2.3 Run Script (Command Line)
92 | ```bash
93 | python main.py --checkpoint_name baseline_amp --amp;
94 | ```
95 |
96 | ## 3. Performance Table
97 | - B : Baseline (FP32)
98 | - AMP : Automatic Mixed Precision training
99 |
100 | | Algorithm     | Test Accuracy (%) | GPU Memory (MB) | Total Training Time (min) |
101 | |:-------------:|:-----------------:|:---------------:|:-------------------------:|
102 | | B - 1080 Ti   | 94.13             | 10737           | 64.9                      |
103 | | B - 2080 Ti   | 94.17             | 10855           | 54.3                      |
104 | | AMP - 1080 Ti | 94.07             | 6615            | 64.7                      |
105 | | AMP - 2080 Ti | 94.23             | 7799            | 37.3                      |
106 |
107 | ## 4. Code Reference
108 | - Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch
109 | - Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr
110 | - PyTorch Automatic Mixed Precision: https://pytorch.org/docs/stable/notes/amp_examples.html
111 |
--------------------------------------------------------------------------------
/assets/banner.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoya012/automatic-mixed-precision-tutorials-pytorch/53f785ee07f5f09078014f7817b87487ea497eb2/assets/banner.PNG
--------------------------------------------------------------------------------
/assets/data_folder.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoya012/automatic-mixed-precision-tutorials-pytorch/53f785ee07f5f09078014f7817b87487ea497eb2/assets/data_folder.PNG
--------------------------------------------------------------------------------
/learning/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from utils import AverageMeter, accuracy
5 |
6 | class Evaluator():
7 | def __init__(self, model, criterion):
8 | self.model = model
9 | self.criterion = criterion
10 | self.save_path = os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name, 'result_dict.json')
11 | if not os.path.exists(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)):
12 | os.makedirs(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name))
13 |
14 | def worst_result(self):
15 | ret = {
16 | 'loss': float('inf'),
17 | 'accuracy': 0.0
18 | }
19 | return ret
20 |
21 | def result_to_str(self, result):
22 | ret = [
23 | 'epoch: {epoch:0>3}',
24 | 'loss: {loss: >4.2e}'
25 | ]
26 | for metric in self.evaluation_metrics:
27 | ret.append('{}: {}'.format(metric.name, metric.fmtstr))
28 | return ', '.join(ret).format(**result)
29 |
30 | def save(self, result):
31 | with open(self.save_path, 'w') as f:
32 | f.write(json.dumps(result, sort_keys=True, indent=4, ensure_ascii=False))
33 |
34 | def load(self):
35 |         result = self.worst_result()
36 |         if os.path.exists(self.save_path):
37 |             with open(self.save_path, 'r') as f:
38 |                 try:
39 |                     result = json.loads(f.read())
40 |                 except json.JSONDecodeError:
41 |                     pass
42 | return result
43 |
44 | def evaluate(self, data_loader, epoch, args, result_dict):
45 | losses = AverageMeter()
46 | top1 = AverageMeter()
47 |
48 | self.model.eval()
49 | total_loss = 0
50 | with torch.no_grad():
51 | for batch_idx, (inputs, labels) in enumerate(data_loader):
52 | inputs, labels = inputs.cuda(), labels.cuda()
53 | if args.amp:
54 | with torch.cuda.amp.autocast():
55 | outputs = self.model(inputs)
56 | loss = self.criterion(outputs, labels)
57 | else:
58 | outputs = self.model(inputs)
59 | loss = self.criterion(outputs, labels)
60 |
61 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
62 | losses.update(loss.item(), inputs.size(0))
63 | top1.update(prec1.item(), inputs.size(0))
64 |
65 | print('----Validation Results Summary----')
66 | print('Epoch: [{}] Top-1 accuracy: {:.2f}%'.format(epoch, top1.avg))
67 |
68 | result_dict['val_loss'].append(losses.avg)
69 | result_dict['val_acc'].append(top1.avg)
70 |
71 | return result_dict
72 |
73 | def test(self, data_loader, args, result_dict):
74 | top1 = AverageMeter()
75 |
76 | self.model.eval()
77 | with torch.no_grad():
78 | for batch_idx, (inputs, labels) in enumerate(data_loader):
79 | inputs, labels = inputs.cuda(), labels.cuda()
80 | outputs = self.model(inputs)
81 |
82 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
83 | top1.update(prec1.item(), inputs.size(0))
84 |
85 | print('----Test Set Results Summary----')
86 | print('Top-1 accuracy: {:.2f}%'.format(top1.avg))
87 |
88 | result_dict['test_acc'].append(top1.avg)
89 |
90 | return result_dict
91 |
92 |
--------------------------------------------------------------------------------
/learning/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | """
6 | Gradually warm up (increase) the learning rate in the optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Code Reference: https://github.com/ildoonet/pytorch-gradual-warmup-lr
10 | """
11 |
12 | class GradualWarmupScheduler(_LRScheduler):
13 |     """ Gradually warm up (increase) the learning rate in the optimizer.
14 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
15 |     Args:
16 |         optimizer (Optimizer): Wrapped optimizer.
17 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, the lr starts from 0 and ends at base_lr.
18 |         total_epoch: the target learning rate is reached gradually over total_epoch epochs.
19 |         after_scheduler: scheduler to use after total_epoch (e.g. CosineAnnealingLR).
20 | """
21 |
22 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
23 | self.multiplier = multiplier
24 | if self.multiplier < 1.:
25 |             raise ValueError('multiplier should be greater than or equal to 1.')
26 | self.total_epoch = total_epoch
27 | self.after_scheduler = after_scheduler
28 | self.finished = False
29 | super(GradualWarmupScheduler, self).__init__(optimizer)
30 |
31 | def get_lr(self):
32 | if self.last_epoch > self.total_epoch:
33 | if self.after_scheduler:
34 | if not self.finished:
35 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
36 | self.finished = True
37 | return self.after_scheduler.get_last_lr()
38 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
39 |
40 | if self.multiplier == 1.0:
41 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
42 | else:
43 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
44 |
45 | def step(self, epoch=None, metrics=None):
46 | if self.finished and self.after_scheduler:
47 | if epoch is None:
48 | self.after_scheduler.step(None)
49 | else:
50 | self.after_scheduler.step(epoch - self.total_epoch)
51 | self._last_lr = self.after_scheduler.get_last_lr()
52 | else:
53 | return super(GradualWarmupScheduler, self).step(epoch)
54 |
--------------------------------------------------------------------------------
/learning/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from utils import AverageMeter, accuracy
4 |
5 | class Trainer:
6 | def __init__(self, model, criterion, optimizer, scheduler, scaler):
7 | self.model = model
8 | self.criterion = criterion
9 | self.optimizer = optimizer
10 | self.scheduler = scheduler
11 | self.scaler = scaler
12 |
13 | def train(self, data_loader, epoch, args, result_dict):
14 | total_loss = 0
15 | count = 0
16 |
17 | losses = AverageMeter()
18 | top1 = AverageMeter()
19 |
20 | self.model.train()
21 |
22 | for batch_idx, (inputs, labels) in enumerate(data_loader):
23 | inputs, labels = inputs.cuda(), labels.cuda()
24 |
25 | if args.amp:
26 | with torch.cuda.amp.autocast():
27 | outputs = self.model(inputs)
28 | loss = self.criterion(outputs, labels)
29 | else:
30 | outputs = self.model(inputs)
31 | loss = self.criterion(outputs, labels)
32 |
33 | if len(labels.size()) > 1:
34 | labels = torch.argmax(labels, axis=1)
35 |
36 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
37 | losses.update(loss.item(), inputs.size(0))
38 | top1.update(prec1.item(), inputs.size(0))
39 |
40 | self.optimizer.zero_grad()
41 |
42 | if args.amp:
43 | self.scaler.scale(loss).backward()
44 | self.scaler.step(self.optimizer)
45 | self.scaler.update()
46 | else:
47 | loss.backward()
48 | self.optimizer.step()
49 |             total_loss += loss.item()
50 | count += labels.size(0)
51 |
52 | if batch_idx % args.log_interval == 0:
53 | _s = str(len(str(len(data_loader.sampler))))
54 | ret = [
55 | ('epoch: {:0>3} [{: >' + _s + '}/{} ({: >3.0f}%)]').format(epoch, count, len(data_loader.sampler), 100 * count / len(data_loader.sampler)),
56 | 'train_loss: {: >4.2e}'.format(total_loss / count),
57 | 'train_accuracy : {:.2f}%'.format(top1.avg)
58 | ]
59 | print(', '.join(ret))
60 |
61 | self.scheduler.step()
62 | result_dict['train_loss'].append(losses.avg)
63 | result_dict['train_acc'].append(top1.avg)
64 |
65 | return result_dict
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, sys, time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 |
7 | PATH = os.path.dirname(os.path.abspath(__file__))
8 | sys.path.insert(0, PATH + '/../..')
9 |
10 | from option import get_args
11 | from learning.trainer import Trainer
12 | from learning.evaluator import Evaluator
13 | from utils import get_model, make_optimizer, make_scheduler, make_dataloader, plot_learning_curves
14 |
15 | def main():
16 | args = get_args()
17 | torch.manual_seed(args.seed)
18 |
19 | shape = (224,224,3)
20 |
21 | """ define dataloader """
22 | train_loader, valid_loader, test_loader = make_dataloader(args)
23 |
24 | """ define model architecture """
25 | model = get_model(args, shape, args.num_classes)
26 |
27 | if torch.cuda.device_count() >= 1:
28 | print('Model pushed to {} GPU(s), type {}.'.format(torch.cuda.device_count(), torch.cuda.get_device_name(0)))
29 | model = model.cuda()
30 | else:
31 | raise ValueError('CPU training is not supported')
32 |
33 | """ define loss criterion """
34 | criterion = nn.CrossEntropyLoss().cuda()
35 |
36 | """ define optimizer """
37 | optimizer = make_optimizer(args, model)
38 |
39 | """ define learning rate scheduler """
40 | scheduler = make_scheduler(args, optimizer)
41 |
42 | """ define loss scaler for automatic mixed precision """
43 | scaler = torch.cuda.amp.GradScaler()
44 |
45 | """ define trainer, evaluator, result_dictionary """
46 | result_dict = {'args':vars(args), 'epoch':[], 'train_loss' : [], 'train_acc' : [], 'val_loss' : [], 'val_acc' : [], 'test_acc':[]}
47 | trainer = Trainer(model, criterion, optimizer, scheduler, scaler)
48 | evaluator = Evaluator(model, criterion)
49 |
50 | train_time_list = []
51 | valid_time_list = []
52 |
53 | if args.evaluate:
54 | """ load model checkpoint """
55 | model.load()
56 | result_dict = evaluator.test(test_loader, args, result_dict)
57 | else:
58 | evaluator.save(result_dict)
59 |
60 | best_val_acc = 0.0
61 | """ define training loop """
62 | for epoch in range(args.epochs):
63 |         result_dict['epoch'].append(epoch)
64 |
65 | torch.cuda.synchronize()
66 | tic1 = time.time()
67 |
68 | result_dict = trainer.train(train_loader, epoch, args, result_dict)
69 |
70 | torch.cuda.synchronize()
71 | tic2 = time.time()
72 | train_time_list.append(tic2 - tic1)
73 |
74 | torch.cuda.synchronize()
75 | tic3 = time.time()
76 |
77 | result_dict = evaluator.evaluate(valid_loader, epoch, args, result_dict)
78 |
79 | torch.cuda.synchronize()
80 | tic4 = time.time()
81 | valid_time_list.append(tic4 - tic3)
82 |
83 | if result_dict['val_acc'][-1] > best_val_acc:
84 | print("{} epoch, best epoch was updated! {}%".format(epoch, result_dict['val_acc'][-1]))
85 | best_val_acc = result_dict['val_acc'][-1]
86 | model.save(checkpoint_name='best_model')
87 |
88 | evaluator.save(result_dict)
89 | plot_learning_curves(result_dict, epoch, args)
90 |
91 | result_dict = evaluator.test(test_loader, args, result_dict)
92 | evaluator.save(result_dict)
93 |
94 | """ calculate test accuracy using best model """
95 | model.load(checkpoint_name='best_model')
96 | result_dict = evaluator.test(test_loader, args, result_dict)
97 | evaluator.save(result_dict)
98 |
99 | print(result_dict)
100 |
101 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'train_time_amp.csv'), train_time_list, delimiter=',', fmt='%s')
102 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'valid_time_amp.csv'), valid_time_list, delimiter=',', fmt='%s')
103 |
104 | if __name__ == '__main__':
105 | main()
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import resnet
5 |
6 | __all__ = [
7 | 'ResNet18', 'ResNet50'
8 | ]
9 |
10 | class Flatten(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, x):
15 | return x.view(x.size(0), -1)
16 |
17 | def conv3x3(in_channel, out_channel, stride=1):
18 | return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
19 |
20 | def conv1x1(in_channel, out_channel, stride=1):
21 | return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False)
22 |
23 | class ResNet(nn.Module):
24 | def __init__(self, block, layer_config, num_classes=2, norm='batch', zero_init_residual=False):
25 | super(ResNet, self).__init__()
26 | norm = nn.BatchNorm2d
27 |
28 | self.in_channel = 64
29 |
30 | self.conv = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
31 | self.norm = norm(self.in_channel)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
34 | self.layer1 = self.make_layer(block, 64*block.expansion, layer_config[0], stride=1, norm=norm)
35 | self.layer2 = self.make_layer(block, 128*block.expansion, layer_config[1], stride=2, norm=norm)
36 | self.layer3 = self.make_layer(block, 256*block.expansion, layer_config[2], stride=2, norm=norm)
37 | self.layer4 = self.make_layer(block, 512*block.expansion, layer_config[3], stride=2, norm=norm)
38 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
39 | self.flatten = Flatten()
40 | self.dense = nn.Linear(512*block.expansion, num_classes)
41 |
42 | for m in self.modules():
43 | if isinstance(m, nn.Conv2d):
44 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
45 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
46 | nn.init.constant_(m.weight, 1)
47 | nn.init.constant_(m.bias, 0)
48 |
49 | # Zero-initialize the last BN in each residual branch,
50 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
51 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
52 | if zero_init_residual:
53 | for m in self.modules():
54 | if isinstance(m, BasicBlock):
55 | nn.init.constant_(m.norm2.weight, 0)
56 | elif isinstance(m, Bottleneck):
57 | nn.init.constant_(m.norm3.weight, 0)
58 |
59 | def make_layer(self, block, out_channel, num_blocks, stride=1, norm=None):
60 | norm = nn.BatchNorm2d
61 |
62 | downsample = None
63 | if stride != 1 or self.in_channel != out_channel:
64 | downsample = nn.Sequential(
65 | conv1x1(self.in_channel, out_channel, stride),
66 | norm(out_channel),
67 | )
68 | layers = []
69 | layers.append(block(self.in_channel, out_channel, stride, downsample, norm))
70 | self.in_channel = out_channel
71 | for _ in range(1, num_blocks):
72 | layers.append(block(self.in_channel, out_channel, norm=norm))
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = self.conv(x)
77 | out = self.norm(out)
78 | out = self.relu(out)
79 | out = self.maxpool(out)
80 | out = self.layer1(out)
81 | out = self.layer2(out)
82 | out = self.layer3(out)
83 | out = self.layer4(out)
84 | out = self.avgpool(out)
85 | out = self.flatten(out)
86 | out = self.dense(out)
87 | return out
88 |
89 |
90 | class BasicBlock(nn.Module):
91 | expansion = 1
92 |
93 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None):
94 | super(BasicBlock, self).__init__()
95 | norm = nn.BatchNorm2d
96 |
97 | self.conv1 = conv3x3(in_channel, out_channel, stride)
98 | self.norm1 = norm(out_channel)
99 | self.conv2 = conv3x3(out_channel, out_channel)
100 | self.norm2 = norm(out_channel)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.downsample = downsample
103 |
104 | def forward(self, x):
105 | identity = x
106 | out = self.conv1(x)
107 | out = self.norm1(out)
108 | out = self.relu(out)
109 | out = self.conv2(out)
110 | out = self.norm2(out)
111 | if self.downsample is not None:
112 | identity = self.downsample(identity)
113 | out += identity
114 | out = self.relu(out)
115 | return out
116 |
117 | class Bottleneck(nn.Module):
118 | expansion = 4
119 |
120 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None):
121 | super(Bottleneck, self).__init__()
122 | norm = nn.BatchNorm2d
123 |
124 |         mid_channel = out_channel // self.expansion
125 | self.conv1 = conv1x1(in_channel, mid_channel)
126 | self.norm1 = norm(mid_channel)
127 | self.conv2 = conv3x3(mid_channel, mid_channel, stride)
128 | self.norm2 = norm(mid_channel)
129 | self.conv3 = conv1x1(mid_channel, out_channel)
130 | self.norm3 = norm(out_channel)
131 | self.relu = nn.ReLU(inplace=True)
132 | self.downsample = downsample
133 |
134 | def forward(self, x):
135 | identity = x
136 | out = self.conv1(x)
137 | out = self.norm1(out)
138 | out = self.relu(out)
139 | out = self.conv2(out)
140 | out = self.norm2(out)
141 | out = self.relu(out)
142 | out = self.conv3(out)
143 | out = self.norm3(out)
144 | if self.downsample is not None:
145 | identity = self.downsample(x)
146 | out += identity
147 | out = self.relu(out)
148 | return out
149 |
150 |
151 | class ResNet18(nn.Module):
152 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet18',
153 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False):
154 | super(ResNet18, self).__init__()
155 |
156 | if len(shape) != 3:
157 | raise ValueError('Invalid shape: {}'.format(shape))
158 | self.shape = shape
159 | self.num_classes = num_classes
160 | self.checkpoint_dir = checkpoint_dir
161 | self.checkpoint_name = checkpoint_name
162 | self.H, self.W, self.C = shape
163 |
164 | if not os.path.exists(checkpoint_dir):
165 | os.makedirs(checkpoint_dir)
166 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt')
167 |
168 |         model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes, norm, zero_init_residual)
169 |
170 | if pretrained:
171 | print("Pretrained Weight is loaded!!")
172 | if pretrained_path is None:
173 | print("Loading from torchvision models")
174 | model = resnet.resnet18(pretrained=True)
175 | if zero_init_residual:
176 | for m in model.modules():
177 | if isinstance(m, resnet.Bottleneck):
178 | nn.init.constant_(m.bn3.weight, 0)
179 | elif isinstance(m, BasicBlock):
180 | nn.init.constant_(m.bn2.weight, 0)
181 | else:
182 | checkpoint = torch.load(pretrained_path)
183 | model.load_state_dict(checkpoint)
184 |
185 | self.features = nn.Sequential(*list(model.children())[:-2])
186 | self.num_features = 512 * BasicBlock.expansion
187 | self.classifier = nn.Sequential(
188 | nn.AdaptiveAvgPool2d((1, 1)),
189 | Flatten(),
190 | nn.Linear(self.num_features, num_classes)
191 | )
192 |
193 | def save(self, checkpoint_name=''):
194 | if checkpoint_name == '':
195 | torch.save(self.state_dict(), self.checkpoint_path)
196 | else:
197 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
198 | torch.save(self.state_dict(), checkpoint_path)
199 |
200 | def load(self, checkpoint_name=''):
201 | if checkpoint_name == '':
202 | self.load_state_dict(torch.load(self.checkpoint_path))
203 | else:
204 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
205 | self.load_state_dict(torch.load(checkpoint_path))
206 |
207 | def forward(self, x):
208 | out = x
209 | out = self.features(out)
210 | out = self.classifier(out)
211 | return out
212 |
213 | class ResNet50(nn.Module):
214 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet50',
215 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False):
216 | super(ResNet50, self).__init__()
217 |
218 | if len(shape) != 3:
219 | raise ValueError('Invalid shape: {}'.format(shape))
220 | self.shape = shape
221 | self.num_classes = num_classes
222 | self.checkpoint_dir = checkpoint_dir
223 | self.checkpoint_name = checkpoint_name
224 | self.H, self.W, self.C = shape
225 |
226 | if not os.path.exists(checkpoint_dir):
227 | os.makedirs(checkpoint_dir)
228 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt')
229 |
230 | model = ResNet(Bottleneck, [3,4,6,3], num_classes, norm, zero_init_residual)
231 |
232 | if pretrained:
233 | print("Pretrained Weight is loaded!!")
234 | if pretrained_path is None:
235 | print("Loading from torchvision models")
236 | model = resnet.resnet50(pretrained=True)
237 | if zero_init_residual:
238 | for m in model.modules():
239 | if isinstance(m, resnet.Bottleneck):
240 | nn.init.constant_(m.bn3.weight, 0)
241 | else:
242 | checkpoint = torch.load(pretrained_path)
243 | model.load_state_dict(checkpoint)
244 |
245 | self.features = nn.Sequential(*list(model.children())[:-2])
246 | self.num_features = 512 * Bottleneck.expansion
247 | self.classifier = nn.Sequential(
248 | nn.AdaptiveAvgPool2d((1, 1)),
249 | Flatten(),
250 | nn.Linear(self.num_features, num_classes)
251 | )
252 |
253 | def save(self, checkpoint_name=''):
254 | if checkpoint_name == '':
255 | torch.save(self.state_dict(), self.checkpoint_path)
256 | else:
257 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
258 | torch.save(self.state_dict(), checkpoint_path)
259 |
260 | def load(self, checkpoint_name=''):
261 | if checkpoint_name == '':
262 | self.load_state_dict(torch.load(self.checkpoint_path))
263 | else:
264 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
265 | self.load_state_dict(torch.load(checkpoint_path))
266 |
267 | def forward(self, x):
268 | out = x
269 | out = self.features(out)
270 | out = self.classifier(out)
271 | return out
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # model architecture & checkpoint
7 | parser.add_argument('--model', default='ResNet18', choices=('ResNet18', 'ResNet50'),
8 |                         help='model architecture to use (ResNet18 | ResNet50)')
9 | parser.add_argument('--norm', default='batchnorm')
10 | parser.add_argument('--num_classes', type=int, default=6)
11 | parser.add_argument('--pretrained', type=int, default=1)
12 | parser.add_argument('--pretrained_path', type=str, default=None)
13 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint')
14 | parser.add_argument('--checkpoint_name', type=str, default='')
15 |
16 | # data loading
17 | parser.add_argument('--num_workers', type=int, default=16)
18 | parser.add_argument('--seed', type=int, default=42, help='random seed')
19 |
20 | # training hyper parameters
21 | parser.add_argument('--batch_size', type=int, default=256)
22 | parser.add_argument('--epochs', type=int, default=120)
23 | parser.add_argument('--log_interval', type=int, default=20)
24 | parser.add_argument('--evaluate', action='store_true', default=False)
25 | parser.add_argument('--amp', action='store_true', default=False)
26 |
27 |     # optimizer & learning rate scheduler
28 | parser.add_argument('--learning_rate', type=float, default=0.0001)
29 | parser.add_argument('--weight_decay', type=float, default=0.0001)
30 | parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM'),
31 | help='optimizer to use (SGD | ADAM)')
32 | parser.add_argument('--decay_type', default='cosine_warmup', choices=('step', 'step_warmup', 'cosine_warmup'),
33 |                         help='learning rate decay schedule to use (step | step_warmup | cosine_warmup)')
34 |
35 | args = parser.parse_args()
36 | return args
37 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler==0.10.0
2 | future==0.18.2
3 | joblib==0.15.1
4 | kiwisolver==1.2.0
5 | matplotlib==3.2.2
6 | numpy==1.19.0
7 | Pillow==7.1.2
8 | pyparsing==2.4.7
9 | python-dateutil==2.8.1
10 | scikit-learn==0.23.1
11 | scipy==1.5.0
12 | six==1.15.0
13 | threadpoolctl==2.1.0
14 | torch==1.6.0
15 | torchvision==0.7.0
16 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | from sklearn.model_selection import train_test_split
6 |
7 | import torch
8 | import torch.optim as optim
9 | import torch.optim.lr_scheduler as lrs
10 | from torch.utils.data.sampler import SubsetRandomSampler
11 | import torchvision
12 | from torchvision import transforms as T
13 |
14 | from network.resnet import *
15 | from learning.lr_scheduler import GradualWarmupScheduler
16 |
17 | def get_model(args, shape, num_classes):
18 | model = eval(args.model)(
19 | shape,
20 | num_classes,
21 | checkpoint_dir=args.checkpoint_dir,
22 | checkpoint_name=args.checkpoint_name,
23 | pretrained=args.pretrained,
24 | pretrained_path=args.pretrained_path,
25 | norm=args.norm,
26 | )#.cuda(args.gpu)
27 | return model
28 |
29 |
30 | def make_optimizer(args, model):
31 | trainable = filter(lambda x: x.requires_grad, model.parameters())
32 |
33 | if args.optimizer == 'SGD':
34 | optimizer_function = optim.SGD
35 | kwargs = {'momentum': 0.9}
36 | elif args.optimizer == 'ADAM':
37 | optimizer_function = optim.Adam
38 | kwargs = {
39 | 'betas': (0.9, 0.999),
40 | 'eps': 1e-08
41 | }
42 | else:
43 |         raise NameError('Unsupported optimizer')
44 |
45 | kwargs['lr'] = args.learning_rate
46 | kwargs['weight_decay'] = args.weight_decay
47 |
48 | return optimizer_function(trainable, **kwargs)
49 |
50 |
51 | def make_scheduler(args, optimizer):
52 | if args.decay_type == 'step':
53 | scheduler = lrs.MultiStepLR(
54 | optimizer,
55 | milestones=[30, 60, 90],
56 | gamma=0.1
57 | )
58 | elif args.decay_type == 'step_warmup':
59 | scheduler = lrs.MultiStepLR(
60 | optimizer,
61 | milestones=[30, 60, 90],
62 | gamma=0.1
63 | )
64 | scheduler = GradualWarmupScheduler(
65 | optimizer,
66 | multiplier=1,
67 | total_epoch=5,
68 | after_scheduler=scheduler
69 | )
70 | elif args.decay_type == 'cosine_warmup':
71 | cosine_scheduler = lrs.CosineAnnealingLR(
72 | optimizer,
73 | T_max=args.epochs
74 | )
75 | scheduler = GradualWarmupScheduler(
76 | optimizer,
77 | multiplier=1,
78 | total_epoch=args.epochs//10,
79 | after_scheduler=cosine_scheduler
80 | )
81 | else:
82 | raise Exception('unknown lr scheduler: {}'.format(args.decay_type))
83 |
84 | return scheduler
85 |
86 | def make_dataloader(args):
87 |
88 | train_trans = T.Compose([
89 | T.Resize((256, 256)),
90 | T.RandomHorizontalFlip(),
91 | T.ToTensor(),
92 | ])
93 |
94 | valid_trans = T.Compose([
95 | T.Resize((256, 256)),
96 | T.ToTensor(),
97 | ])
98 |
99 | test_trans = T.Compose([
100 | T.Resize((256, 256)),
101 | T.ToTensor(),
102 | ])
103 |
104 | trainset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=train_trans)
105 | validset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=valid_trans)
106 | testset = torchvision.datasets.ImageFolder(root="data/seg_test/seg_test", transform=test_trans)
107 |
108 | np.random.seed(args.seed)
109 | targets = trainset.targets
110 | train_idx, valid_idx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets)
111 | train_sampler = SubsetRandomSampler(train_idx)
112 | valid_sampler = SubsetRandomSampler(valid_idx)
113 |
114 | train_loader = torch.utils.data.DataLoader(
115 | trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers
116 | )
117 |
118 | valid_loader = torch.utils.data.DataLoader(
119 | validset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=args.num_workers
120 | )
121 |
122 | test_loader = torch.utils.data.DataLoader(
123 | testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
124 | )
125 |
126 | return train_loader, valid_loader, test_loader
127 |
128 |
129 | def plot_learning_curves(metrics, cur_epoch, args):
130 | x = np.arange(cur_epoch+1)
131 | fig, ax1 = plt.subplots()
132 | ax1.set_xlabel('epochs')
133 | ax1.set_ylabel('loss')
134 | ln1 = ax1.plot(x, metrics['train_loss'], color='tab:red')
135 | ln2 = ax1.plot(x, metrics['val_loss'], color='tab:red', linestyle='dashed')
136 | ax1.grid()
137 | ax2 = ax1.twinx()
138 | ax2.set_ylabel('accuracy')
139 | ln3 = ax2.plot(x, metrics['train_acc'], color='tab:blue')
140 | ln4 = ax2.plot(x, metrics['val_acc'], color='tab:blue', linestyle='dashed')
141 | lns = ln1+ln2+ln3+ln4
142 | plt.legend(lns, ['Train loss', 'Validation loss', 'Train accuracy','Validation accuracy'])
143 | plt.tight_layout()
144 | plt.savefig('{}/{}/learning_curve.png'.format(args.checkpoint_dir, args.checkpoint_name), bbox_inches='tight')
145 | plt.close('all')
146 |
147 | class AverageMeter(object):
148 | """Computes and stores the average and current value"""
149 | def __init__(self):
150 | self.reset()
151 |
152 | def reset(self):
153 | self.val = 0
154 | self.avg = 0
155 | self.sum = 0
156 | self.count = 0
157 | self.max = 0
158 | self.min = 1e5
159 |
160 | def update(self, val, n=1):
161 | self.val = val
162 | self.sum += val * n
163 | self.count += n
164 | self.avg = self.sum / self.count
165 | if val > self.max:
166 | self.max = val
167 | if val < self.min:
168 | self.min = val
169 |
170 | def accuracy(output, target, topk=(1,)):
171 | """Computes the precision@k for the specified values of k"""
172 | maxk = max(topk)
173 | batch_size = target.size(0)
174 |
175 | _, pred = output.topk(maxk, 1, True, True)
176 | pred = pred.t()
177 | correct = pred.eq(target.view(1, -1).expand_as(pred))
178 |
179 | res = []
180 | for k in topk:
181 |         correct_k = correct[:k].reshape(-1).float().sum(0)
182 | res.append(correct_k.mul_(100.0 / batch_size))
183 | return res
--------------------------------------------------------------------------------