├── LICENSE
├── README.md
├── assets
│   └── data_folder.PNG
├── learning
│   ├── evaluator.py
│   ├── lr_scheduler.py
│   └── trainer.py
├── main.py
├── network
│   └── resnet.py
├── option.py
├── requirements.txt
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Hoseong Lee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # swa-tutorials-pytorch
3 | [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) tutorials using PyTorch. Based on [PyTorch 1.6 Official Features (Stochastic Weight Averaging)](https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/), this repository implements an image classification codebase on a custom dataset.
4 |
5 | - author: hoya012
6 | - last update: 2020.10.23
7 |
8 | ## 0. Experimental Setup
9 | ### 0-1. Prepare Library
10 | - Install PyTorch, Captum, and the other required libraries:
11 |
12 | ```bash
13 | pip install -r requirements.txt
14 | ```
15 |
16 | ### 0-2. Download dataset (Kaggle Intel Image Classification)
17 |
18 | - [Intel Image Classification](https://www.kaggle.com/puneet6060/intel-image-classification/)
19 |
20 | This dataset contains around 25k images of size 150x150, distributed across 6 categories:
21 | {'buildings' -> 0,
22 | 'forest' -> 1,
23 | 'glacier' -> 2,
24 | 'mountain' -> 3,
25 | 'sea' -> 4,
26 | 'street' -> 5 }
27 |
28 | - Make a `data` folder and move the dataset into it; the sketch below shows how the class folders map to labels.
29 |
30 |
31 |
32 |
33 |
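Once the data is in place, `torchvision.datasets.ImageFolder` derives the labels from the alphabetically sorted class-folder names, which yields exactly the mapping above. A minimal sketch to double-check the layout, assuming the Kaggle archive was extracted into `data/` so that `data/seg_train/seg_train/<class>/` exists (the same paths `utils.py` uses):

```python
from torchvision import datasets

# Expects data/seg_train/seg_train/{buildings,forest,glacier,mountain,sea,street}/
trainset = datasets.ImageFolder(root="data/seg_train/seg_train")

# ImageFolder sorts class folders alphabetically, so this should print
# {'buildings': 0, 'forest': 1, 'glacier': 2, 'mountain': 3, 'sea': 4, 'street': 5}
print(trainset.class_to_idx)
print(len(trainset))  # number of training images found
```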
34 | ## 1. Baseline Training
35 | - ImageNet Pretrained ResNet-18 from torchvision.models
36 | - Batch Size 256 / Epochs 120 / Initial Learning Rate 0.0001
37 | - Training Augmentation: Resize((256, 256)), RandomHorizontalFlip()
38 | - Adam + Cosine Learning rate scheduling with warmup
39 | - Trained on a single NVIDIA GTX 1080 Ti (Pascal) GPU
40 |
41 | ```bash
42 | python main.py --checkpoint_name baseline;
43 | ```
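The Adam + cosine-with-warmup setup above is built by `make_optimizer` and `make_scheduler` in `utils.py`. Condensed, and using the defaults from `option.py` (`model` is the ResNet defined by this repo), it looks roughly like this:

```python
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from learning.lr_scheduler import GradualWarmupScheduler

optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Cosine annealing over all 120 epochs, wrapped in a linear warmup
# that runs for the first epochs // 10 = 12 epochs.
cosine = lrs.CosineAnnealingLR(optimizer, T_max=120)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=12, after_scheduler=cosine)
```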
44 |
45 | ## 2. Stochastic Weight Averaging Training
46 |
47 | With PyTorch 1.6, Stochastic Weight Averaging is very easy to use thanks to the new `torch.optim.swa_utils` module.
48 |
49 | - The official PyTorch example:
50 | ```python
51 | from torch.optim.swa_utils import AveragedModel, SWALR
52 | from torch.optim.lr_scheduler import CosineAnnealingLR
53 |
54 | loader, optimizer, model, loss_fn = ...
55 | swa_model = AveragedModel(model)
56 | scheduler = CosineAnnealingLR(optimizer, T_max=100)
57 | swa_start = 5
58 | swa_scheduler = SWALR(optimizer, swa_lr=0.05)
59 |
60 | for epoch in range(100):
61 | for input, target in loader:
62 | optimizer.zero_grad()
63 | loss_fn(model(input), target).backward()
64 | optimizer.step()
65 | if epoch > swa_start:
66 | swa_model.update_parameters(model)
67 | swa_scheduler.step()
68 | else:
69 | scheduler.step()
70 |
71 | # Update bn statistics for the swa_model at the end
72 | torch.optim.swa_utils.update_bn(loader, swa_model)
73 | # Use swa_model to make predictions on test data
74 | preds = swa_model(test_input)
75 | ```
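Two details are easy to miss here: by default `AveragedModel` keeps an equal-weight running mean of the parameters, and batch-norm running statistics are not averaged, which is why `update_bn` makes one extra pass over the training data at the end. A rough sketch of the default averaging rule (not the actual PyTorch source):

```python
# Each call to swa_model.update_parameters(model) updates every averaged
# parameter with an equal-weight running mean, roughly:
def default_avg(averaged_param, current_param, num_averaged):
    return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
```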
76 |
77 | - My own implementation (excerpts from this repository):
78 | ```python
79 | # in main.py
80 | """ define model and learning rate scheduler for stochastic weight averaging """
81 | swa_model = torch.optim.swa_utils.AveragedModel(model)
82 | swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr)
83 |
84 | ...
85 |
86 | # in learning/trainer.py
87 | for batch_idx, (inputs, labels) in enumerate(data_loader):
88 | if not args.decay_type == 'swa':
89 | self.scheduler.step()
90 | else:
91 | if epoch <= args.swa_start:
92 | self.scheduler.step()
93 |
94 | if epoch > args.swa_start and args.decay_type == 'swa':
95 | self.swa_model.update_parameters(self.model)
96 | self.swa_scheduler.step()
97 |
98 | ...
99 |
100 | # in main.py
101 | swa_model = swa_model.cpu()
102 | torch.optim.swa_utils.update_bn(train_loader, swa_model)
103 | swa_model = swa_model.cuda()
104 | ```
105 |
106 | #### Run Script (Command Line)
107 | ```bash
108 | python main.py --checkpoint_name swa --decay_type swa --swa_start 90 --swa_lr 5e-5;
109 | ```
110 |
111 | ## 3. Performance Table
112 | - B : Baseline
113 | - SWA : Stochastic Weight Averaging
114 | - SWA_{swa_start}_{swa_lr} : SWA run with the given `swa_start` epoch and `swa_lr`
115 |
116 | | Algorithm | Test Accuracy (%) |
117 | |:------------:|:-------------:|
118 | | B | 94.10 |
119 | | SWA_90_0.05 | 80.53 |
120 | | SWA_90_1e-4 | 94.20 |
121 | | SWA_90_5e-4 | 93.87 |
122 | | SWA_90_1e-5 | 94.23 |
123 | | SWA_90_5e-5 | **94.57** |
124 | | SWA_75_5e-5 | 94.27 |
125 | | SWA_60_5e-5 | 94.33 |
126 |
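Each row corresponds to one run of `main.py` with the matching `--swa_start` and `--swa_lr`; the `--checkpoint_name` values below are just illustrative. For example, the SWA_90_5e-5 and SWA_75_5e-5 rows come from runs like:

```bash
python main.py --checkpoint_name swa_90_5e-5 --decay_type swa --swa_start 90 --swa_lr 5e-5;
python main.py --checkpoint_name swa_75_5e-5 --decay_type swa --swa_start 75 --swa_lr 5e-5;
```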
127 | ## 4. Code Reference
128 | - Baseline Code: https://github.com/hoya012/carrier-of-tricks-for-classification-pytorch
129 | - Gradual Warmup Scheduler: https://github.com/ildoonet/pytorch-gradual-warmup-lr
130 | - PyTorch Stochastic Weight Averaging: https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging
--------------------------------------------------------------------------------
/assets/data_folder.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hoya012/swa-tutorials-pytorch/ab8cace9a1457a74cadfe6671598237a4d389cc0/assets/data_folder.PNG
--------------------------------------------------------------------------------
/learning/evaluator.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from utils import AverageMeter, accuracy
5 |
6 | class Evaluator():
7 | def __init__(self, model, criterion, swa_model=None):
8 | self.model = model
9 | self.criterion = criterion
10 | self.swa_model = swa_model
11 | self.save_path = os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name, 'result_dict.json')
12 | if not os.path.exists(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name)):
13 | os.makedirs(os.path.join(self.model.checkpoint_dir, self.model.checkpoint_name))
14 |
15 | def worst_result(self):
16 | ret = {
17 | 'loss': float('inf'),
18 | 'accuracy': 0.0
19 | }
20 | return ret
21 |
22 | def result_to_str(self, result):
23 | ret = [
24 | 'epoch: {epoch:0>3}',
25 | 'loss: {loss: >4.2e}'
26 | ]
27 | for metric in self.evaluation_metrics:
28 | ret.append('{}: {}'.format(metric.name, metric.fmtstr))
29 | return ', '.join(ret).format(**result)
30 |
31 | def save(self, result):
32 | with open(self.save_path, 'w') as f:
33 | f.write(json.dumps(result, sort_keys=True, indent=4, ensure_ascii=False))
34 |
35 | def load(self):
36 | result = self.worst_result()
37 | if os.path.exists(self.save_path):
38 | with open(self.save_path, 'r') as f:
39 | try:
40 | result = json.loads(f.read())
41 | except json.JSONDecodeError:
42 | pass
43 | return result
44 |
45 | def evaluate(self, data_loader, epoch, args, result_dict):
46 | losses = AverageMeter()
47 | top1 = AverageMeter()
48 |
49 | self.model.eval()
50 |
51 | total_loss = 0
52 | with torch.no_grad():
53 | for batch_idx, (inputs, labels) in enumerate(data_loader):
54 | inputs, labels = inputs.cuda(), labels.cuda()
55 | if args.amp:
56 | with torch.cuda.amp.autocast():
57 | outputs = self.model(inputs)
58 | loss = self.criterion(outputs, labels)
59 | else:
60 | outputs = self.model(inputs)
61 | loss = self.criterion(outputs, labels)
62 |
63 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
64 | losses.update(loss.item(), inputs.size(0))
65 | top1.update(prec1.item(), inputs.size(0))
66 |
67 | print('----Validation Results Summary----')
68 | print('Epoch: [{}] Top-1 accuracy: {:.2f}%'.format(epoch, top1.avg))
69 |
70 | result_dict['val_loss'].append(losses.avg)
71 | result_dict['val_acc'].append(top1.avg)
72 |
73 | return result_dict
74 |
75 | def test(self, data_loader, args, result_dict):
76 | top1 = AverageMeter()
77 |
78 | if args.decay_type == 'swa':
79 | self.model = self.swa_model
80 |
81 | self.model.eval()
82 | with torch.no_grad():
83 | for batch_idx, (inputs, labels) in enumerate(data_loader):
84 | inputs, labels = inputs.cuda(), labels.cuda()
85 | outputs = self.model(inputs)
86 |
87 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
88 | top1.update(prec1.item(), inputs.size(0))
89 |
90 | print('----Test Set Results Summary----')
91 | print('Top-1 accuracy: {:.2f}%'.format(top1.avg))
92 |
93 | result_dict['test_acc'].append(top1.avg)
94 |
95 | return result_dict
96 |
97 |
--------------------------------------------------------------------------------
/learning/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 | """
6 | Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 | Code Reference: https://github.com/ildoonet/pytorch-gradual-warmup-lr
10 | """
11 |
12 | class GradualWarmupScheduler(_LRScheduler):
13 | """ Gradually warm-up(increasing) learning rate in optimizer.
14 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
15 | Args:
16 | optimizer (Optimizer): Wrapped optimizer.
17 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
18 | total_epoch: target learning rate is reached at total_epoch, gradually
19 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
20 | """
21 |
22 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
23 | self.multiplier = multiplier
24 | if self.multiplier < 1.:
25 | raise ValueError('multiplier should be greater than or equal to 1.')
26 | self.total_epoch = total_epoch
27 | self.after_scheduler = after_scheduler
28 | self.finished = False
29 | super(GradualWarmupScheduler, self).__init__(optimizer)
30 |
31 | def get_lr(self):
32 | if self.last_epoch > self.total_epoch:
33 | if self.after_scheduler:
34 | if not self.finished:
35 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
36 | self.finished = True
37 | return self.after_scheduler.get_last_lr()
38 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
39 |
40 | if self.multiplier == 1.0:
41 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
42 | else:
43 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
44 |
45 | def step(self, epoch=None, metrics=None):
46 | if self.finished and self.after_scheduler:
47 | if epoch is None:
48 | self.after_scheduler.step(None)
49 | else:
50 | self.after_scheduler.step(epoch - self.total_epoch)
51 | self._last_lr = self.after_scheduler.get_last_lr()
52 | else:
53 | return super(GradualWarmupScheduler, self).step(epoch)
54 |
--------------------------------------------------------------------------------
/learning/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from utils import AverageMeter, accuracy
4 |
5 | class Trainer:
6 | def __init__(self, model, criterion, optimizer, scheduler, scaler, swa_model=None, swa_scheduler=None):
7 | self.model = model
8 | self.criterion = criterion
9 | self.optimizer = optimizer
10 | self.scheduler = scheduler
11 | self.scaler = scaler
12 | self.swa_model = swa_model
13 | self.swa_scheduler = swa_scheduler
14 |
15 | def train(self, data_loader, epoch, args, result_dict):
16 | total_loss = 0
17 | count = 0
18 |
19 | losses = AverageMeter()
20 | top1 = AverageMeter()
21 |
22 | self.model.train()
23 |
24 | for batch_idx, (inputs, labels) in enumerate(data_loader):
25 | inputs, labels = inputs.cuda(), labels.cuda()
26 |
27 | if args.amp:
28 | with torch.cuda.amp.autocast():
29 | outputs = self.model(inputs)
30 | loss = self.criterion(outputs, labels)
31 | else:
32 | outputs = self.model(inputs)
33 | loss = self.criterion(outputs, labels)
34 |
35 | if len(labels.size()) > 1:
36 | labels = torch.argmax(labels, axis=1)
37 |
38 | prec1, prec3 = accuracy(outputs.data, labels, topk=(1, 3))
39 | losses.update(loss.item(), inputs.size(0))
40 | top1.update(prec1.item(), inputs.size(0))
41 |
42 | self.optimizer.zero_grad()
43 |
44 | if args.amp:
45 | self.scaler.scale(loss).backward()
46 | self.scaler.step(self.optimizer)
47 | self.scaler.update()
48 | else:
49 | loss.backward()
50 | self.optimizer.step()
51 |
52 |
53 | total_loss += loss.item()
54 | count += labels.size(0)
55 |
56 | if batch_idx % args.log_interval == 0:
57 | _s = str(len(str(len(data_loader.sampler))))
58 | ret = [
59 | ('epoch: {:0>3} [{: >' + _s + '}/{} ({: >3.0f}%)]').format(epoch, count, len(data_loader.sampler), 100 * count / len(data_loader.sampler)),
60 | 'train_loss: {: >4.2e}'.format(total_loss / count),
61 | 'train_accuracy : {:.2f}%'.format(top1.avg)
62 | ]
63 | print(', '.join(ret))
64 |
65 | if not args.decay_type == 'swa':
66 | self.scheduler.step()
67 | else:
68 | if epoch <= args.swa_start:
69 | self.scheduler.step()
70 |
71 | if epoch > args.swa_start and args.decay_type == 'swa':
72 | self.swa_model.update_parameters(self.model)
73 | self.swa_scheduler.step()
74 |
75 | result_dict['train_loss'].append(losses.avg)
76 | result_dict['train_acc'].append(top1.avg)
77 |
78 | return result_dict
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os, sys, time
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | from torch.optim.swa_utils import SWALR
7 |
8 | PATH = os.path.dirname(os.path.abspath(__file__))
9 | sys.path.insert(0, PATH + '/../..')
10 |
11 | from option import get_args
12 | from learning.trainer import Trainer
13 | from learning.evaluator import Evaluator
14 | from utils import get_model, make_optimizer, make_scheduler, make_dataloader, plot_learning_curves
15 |
16 | def main():
17 | args = get_args()
18 | torch.manual_seed(args.seed)
19 | torch.backends.cudnn.benchmark = True
20 |
21 | shape = (224,224,3)
22 |
23 | """ define dataloader """
24 | train_loader, valid_loader, test_loader = make_dataloader(args)
25 |
26 | """ define model architecture """
27 | model = get_model(args, shape, args.num_classes)
28 |
29 | if torch.cuda.device_count() >= 1:
30 | print('Model pushed to {} GPU(s), type {}.'.format(torch.cuda.device_count(), torch.cuda.get_device_name(0)))
31 | model = model.cuda()
32 | else:
33 | raise ValueError('CPU training is not supported')
34 |
35 | """ define loss criterion """
36 | criterion = nn.CrossEntropyLoss().cuda()
37 |
38 | """ define optimizer """
39 | optimizer = make_optimizer(args, model)
40 |
41 | """ define learning rate scheduler """
42 | scheduler = make_scheduler(args, optimizer)
43 |
44 | """ define loss scaler for automatic mixed precision """
45 | scaler = torch.cuda.amp.GradScaler()
46 |
47 | """ define model and learning rate scheduler for stochastic weight averaging """
48 | swa_model = torch.optim.swa_utils.AveragedModel(model)
49 | swa_scheduler = SWALR(optimizer, swa_lr=args.swa_lr)
50 |
51 | """ define trainer, evaluator, result_dictionary """
52 | result_dict = {'args':vars(args), 'epoch':[], 'train_loss' : [], 'train_acc' : [], 'val_loss' : [], 'val_acc' : [], 'test_acc':[]}
53 | trainer = Trainer(model, criterion, optimizer, scheduler, scaler, swa_model, swa_scheduler)
54 | evaluator = Evaluator(model, criterion, swa_model)
55 |
56 | train_time_list = []
57 | valid_time_list = []
58 |
59 | if args.evaluate:
60 | """ load model checkpoint """
61 | model.load()
62 | result_dict = evaluator.test(test_loader, args, result_dict)
63 | else:
64 | evaluator.save(result_dict)
65 |
66 | best_val_acc = 0.0
67 | """ define training loop """
68 | for epoch in range(args.epochs):
69 | result_dict['epoch'] = epoch
70 |
71 | torch.cuda.synchronize()
72 | tic1 = time.time()
73 |
74 | result_dict = trainer.train(train_loader, epoch, args, result_dict)
75 |
76 | torch.cuda.synchronize()
77 | tic2 = time.time()
78 | train_time_list.append(tic2 - tic1)
79 |
80 | torch.cuda.synchronize()
81 | tic3 = time.time()
82 |
83 | result_dict = evaluator.evaluate(valid_loader, epoch, args, result_dict)
84 |
85 | torch.cuda.synchronize()
86 | tic4 = time.time()
87 | valid_time_list.append(tic4 - tic3)
88 |
89 | if result_dict['val_acc'][-1] > best_val_acc:
90 | print("{} epoch, best epoch was updated! {}%".format(epoch, result_dict['val_acc'][-1]))
91 | best_val_acc = result_dict['val_acc'][-1]
92 | model.save(checkpoint_name='best_model')
93 |
94 | evaluator.save(result_dict)
95 | plot_learning_curves(result_dict, epoch, args)
96 |
97 | if args.decay_type == 'swa':
98 | swa_model = swa_model.cpu()
99 | torch.optim.swa_utils.update_bn(train_loader, swa_model)
100 | swa_model = swa_model.cuda()
101 |
102 | result_dict = evaluator.test(test_loader, args, result_dict)
103 | evaluator.save(result_dict)
104 |
105 | """ calculate test accuracy using best model """
106 | model.load(checkpoint_name='best_model')
107 | result_dict = evaluator.test(test_loader, args, result_dict)
108 | evaluator.save(result_dict)
109 |
110 | print(result_dict)
111 |
112 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'train_time_amp.csv'), train_time_list, delimiter=',', fmt='%s')
113 | np.savetxt(os.path.join(model.checkpoint_dir, model.checkpoint_name, 'valid_time_amp.csv'), valid_time_list, delimiter=',', fmt='%s')
114 |
115 | if __name__ == '__main__':
116 | main()
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | from torchvision.models import resnet
5 |
6 | __all__ = [
7 | 'ResNet18', 'ResNet50'
8 | ]
9 |
10 | class Flatten(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, x):
15 | return x.view(x.size(0), -1)
16 |
17 | def conv3x3(in_channel, out_channel, stride=1):
18 | return nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=1, bias=False)
19 |
20 | def conv1x1(in_channel, out_channel, stride=1):
21 | return nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False)
22 |
23 | class ResNet(nn.Module):
24 | def __init__(self, block, layer_config, num_classes=2, norm='batch', zero_init_residual=False):
25 | super(ResNet, self).__init__()
26 | norm = nn.BatchNorm2d
27 |
28 | self.in_channel = 64
29 |
30 | self.conv = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2, padding=3, bias=False)
31 | self.norm = norm(self.in_channel)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
34 | self.layer1 = self.make_layer(block, 64*block.expansion, layer_config[0], stride=1, norm=norm)
35 | self.layer2 = self.make_layer(block, 128*block.expansion, layer_config[1], stride=2, norm=norm)
36 | self.layer3 = self.make_layer(block, 256*block.expansion, layer_config[2], stride=2, norm=norm)
37 | self.layer4 = self.make_layer(block, 512*block.expansion, layer_config[3], stride=2, norm=norm)
38 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
39 | self.flatten = Flatten()
40 | self.dense = nn.Linear(512*block.expansion, num_classes)
41 |
42 | for m in self.modules():
43 | if isinstance(m, nn.Conv2d):
44 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
45 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
46 | nn.init.constant_(m.weight, 1)
47 | nn.init.constant_(m.bias, 0)
48 |
49 | # Zero-initialize the last BN in each residual branch,
50 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
51 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
52 | if zero_init_residual:
53 | for m in self.modules():
54 | if isinstance(m, BasicBlock):
55 | nn.init.constant_(m.norm2.weight, 0)
56 | elif isinstance(m, Bottleneck):
57 | nn.init.constant_(m.norm3.weight, 0)
58 |
59 | def make_layer(self, block, out_channel, num_blocks, stride=1, norm=None):
60 | norm = nn.BatchNorm2d
61 |
62 | downsample = None
63 | if stride != 1 or self.in_channel != out_channel:
64 | downsample = nn.Sequential(
65 | conv1x1(self.in_channel, out_channel, stride),
66 | norm(out_channel),
67 | )
68 | layers = []
69 | layers.append(block(self.in_channel, out_channel, stride, downsample, norm))
70 | self.in_channel = out_channel
71 | for _ in range(1, num_blocks):
72 | layers.append(block(self.in_channel, out_channel, norm=norm))
73 | return nn.Sequential(*layers)
74 |
75 | def forward(self, x):
76 | out = self.conv(x)
77 | out = self.norm(out)
78 | out = self.relu(out)
79 | out = self.maxpool(out)
80 | out = self.layer1(out)
81 | out = self.layer2(out)
82 | out = self.layer3(out)
83 | out = self.layer4(out)
84 | out = self.avgpool(out)
85 | out = self.flatten(out)
86 | out = self.dense(out)
87 | return out
88 |
89 |
90 | class BasicBlock(nn.Module):
91 | expansion = 1
92 |
93 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None):
94 | super(BasicBlock, self).__init__()
95 | norm = nn.BatchNorm2d
96 |
97 | self.conv1 = conv3x3(in_channel, out_channel, stride)
98 | self.norm1 = norm(out_channel)
99 | self.conv2 = conv3x3(out_channel, out_channel)
100 | self.norm2 = norm(out_channel)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.downsample = downsample
103 |
104 | def forward(self, x):
105 | identity = x
106 | out = self.conv1(x)
107 | out = self.norm1(out)
108 | out = self.relu(out)
109 | out = self.conv2(out)
110 | out = self.norm2(out)
111 | if self.downsample is not None:
112 | identity = self.downsample(identity)
113 | out += identity
114 | out = self.relu(out)
115 | return out
116 |
117 | class Bottleneck(nn.Module):
118 | expansion = 4
119 |
120 | def __init__(self, in_channel, out_channel, stride=1, downsample=None, norm=None):
121 | super(Bottleneck, self).__init__()
122 | norm = nn.BatchNorm2d
123 |
124 | mid_channel= out_channel // self.expansion
125 | self.conv1 = conv1x1(in_channel, mid_channel)
126 | self.norm1 = norm(mid_channel)
127 | self.conv2 = conv3x3(mid_channel, mid_channel, stride)
128 | self.norm2 = norm(mid_channel)
129 | self.conv3 = conv1x1(mid_channel, out_channel)
130 | self.norm3 = norm(out_channel)
131 | self.relu = nn.ReLU(inplace=True)
132 | self.downsample = downsample
133 |
134 | def forward(self, x):
135 | identity = x
136 | out = self.conv1(x)
137 | out = self.norm1(out)
138 | out = self.relu(out)
139 | out = self.conv2(out)
140 | out = self.norm2(out)
141 | out = self.relu(out)
142 | out = self.conv3(out)
143 | out = self.norm3(out)
144 | if self.downsample is not None:
145 | identity = self.downsample(x)
146 | out += identity
147 | out = self.relu(out)
148 | return out
149 |
150 |
151 | class ResNet18(nn.Module):
152 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet18',
153 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False):
154 | super(ResNet18, self).__init__()
155 |
156 | if len(shape) != 3:
157 | raise ValueError('Invalid shape: {}'.format(shape))
158 | self.shape = shape
159 | self.num_classes = num_classes
160 | self.checkpoint_dir = checkpoint_dir
161 | self.checkpoint_name = checkpoint_name
162 | self.H, self.W, self.C = shape
163 |
164 | if not os.path.exists(checkpoint_dir):
165 | os.makedirs(checkpoint_dir)
166 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt')
167 |
168 | model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes, norm, zero_init_residual)
169 |
170 | if pretrained:
171 | print("Pretrained Weight is loaded!!")
172 | if pretrained_path is None:
173 | print("Loading from torchvision models")
174 | model = resnet.resnet18(pretrained=True)
175 | if zero_init_residual:
176 | for m in model.modules():
177 | if isinstance(m, resnet.Bottleneck):
178 | nn.init.constant_(m.bn3.weight, 0)
179 | elif isinstance(m, BasicBlock):
180 | nn.init.constant_(m.bn2.weight, 0)
181 | else:
182 | checkpoint = torch.load(pretrained_path)
183 | model.load_state_dict(checkpoint)
184 |
185 | self.features = nn.Sequential(*list(model.children())[:-2])
186 | self.num_features = 512 * BasicBlock.expansion
187 | self.classifier = nn.Sequential(
188 | nn.AdaptiveAvgPool2d((1, 1)),
189 | Flatten(),
190 | nn.Linear(self.num_features, num_classes)
191 | )
192 |
193 | def save(self, checkpoint_name=''):
194 | if checkpoint_name == '':
195 | torch.save(self.state_dict(), self.checkpoint_path)
196 | else:
197 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
198 | torch.save(self.state_dict(), checkpoint_path)
199 |
200 | def load(self, checkpoint_name=''):
201 | if checkpoint_name == '':
202 | self.load_state_dict(torch.load(self.checkpoint_path))
203 | else:
204 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
205 | self.load_state_dict(torch.load(checkpoint_path))
206 |
207 | def forward(self, x):
208 | out = x
209 | out = self.features(out)
210 | out = self.classifier(out)
211 | return out
212 |
213 | class ResNet50(nn.Module):
214 | def __init__(self, shape, num_classes=2, checkpoint_dir='checkpoint', checkpoint_name='ResNet50',
215 | pretrained=False, pretrained_path=None, norm='batch', zero_init_residual=False):
216 | super(ResNet50, self).__init__()
217 |
218 | if len(shape) != 3:
219 | raise ValueError('Invalid shape: {}'.format(shape))
220 | self.shape = shape
221 | self.num_classes = num_classes
222 | self.checkpoint_dir = checkpoint_dir
223 | self.checkpoint_name = checkpoint_name
224 | self.H, self.W, self.C = shape
225 |
226 | if not os.path.exists(checkpoint_dir):
227 | os.makedirs(checkpoint_dir)
228 | self.checkpoint_path = os.path.join(checkpoint_dir, checkpoint_name, 'model.pt')
229 |
230 | model = ResNet(Bottleneck, [3,4,6,3], num_classes, norm, zero_init_residual)
231 |
232 | if pretrained:
233 | print("Pretrained Weight is loaded!!")
234 | if pretrained_path is None:
235 | print("Loading from torchvision models")
236 | model = resnet.resnet50(pretrained=True)
237 | if zero_init_residual:
238 | for m in model.modules():
239 | if isinstance(m, resnet.Bottleneck):
240 | nn.init.constant_(m.bn3.weight, 0)
241 | else:
242 | checkpoint = torch.load(pretrained_path)
243 | model.load_state_dict(checkpoint)
244 |
245 | self.features = nn.Sequential(*list(model.children())[:-2])
246 | self.num_features = 512 * Bottleneck.expansion
247 | self.classifier = nn.Sequential(
248 | nn.AdaptiveAvgPool2d((1, 1)),
249 | Flatten(),
250 | nn.Linear(self.num_features, num_classes)
251 | )
252 |
253 | def save(self, checkpoint_name=''):
254 | if checkpoint_name == '':
255 | torch.save(self.state_dict(), self.checkpoint_path)
256 | else:
257 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
258 | torch.save(self.state_dict(), checkpoint_path)
259 |
260 | def load(self, checkpoint_name=''):
261 | if checkpoint_name == '':
262 | self.load_state_dict(torch.load(self.checkpoint_path))
263 | else:
264 | checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name, checkpoint_name + '.pt')
265 | self.load_state_dict(torch.load(checkpoint_path))
266 |
267 | def forward(self, x):
268 | out = x
269 | out = self.features(out)
270 | out = self.classifier(out)
271 | return out
--------------------------------------------------------------------------------
/option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args():
4 | parser = argparse.ArgumentParser()
5 |
6 | # model architecture & checkpoint
7 | parser.add_argument('--model', default='ResNet18', choices=('ResNet18', 'ResNet50'),
8 | help='model architecture to use (ResNet18 | ResNet50)')
9 | parser.add_argument('--norm', default='batchnorm')
10 | parser.add_argument('--num_classes', type=int, default=6)
11 | parser.add_argument('--pretrained', type=int, default=1)
12 | parser.add_argument('--pretrained_path', type=str, default=None)
13 | parser.add_argument('--checkpoint_dir', type=str, default='checkpoint')
14 | parser.add_argument('--checkpoint_name', type=str, default='')
15 |
16 | # data loading
17 | parser.add_argument('--num_workers', type=int, default=16)
18 | parser.add_argument('--seed', type=int, default=42, help='random seed')
19 |
20 | # training hyper parameters
21 | parser.add_argument('--batch_size', type=int, default=256)
22 | parser.add_argument('--epochs', type=int, default=120)
23 | parser.add_argument('--log_interval', type=int, default=20)
24 | parser.add_argument('--evaluate', action='store_true', default=False)
25 | parser.add_argument('--amp', action='store_true', default=False)
26 |
27 | # optimizer & learning rate scheduler
28 | parser.add_argument('--learning_rate', type=float, default=0.0001)
29 | parser.add_argument('--weight_decay', type=float, default=0.0001)
30 | parser.add_argument('--optimizer', default='ADAM', choices=('SGD', 'ADAM'),
31 | help='optimizer to use (SGD | ADAM)')
32 | parser.add_argument('--decay_type', default='cosine_warmup', choices=('step', 'step_warmup', 'cosine_warmup', 'swa'),
33 | help='learning rate decay type (step | step_warmup | cosine_warmup | swa: stochastic weight averaging)')
34 | parser.add_argument('--swa_start', type=int, default=90)
35 | parser.add_argument('--swa_lr', type=float, default=0.05)
36 |
37 |
38 | args = parser.parse_args()
39 | return args
40 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | captum==0.2.0
2 | cycler==0.10.0
3 | decorator==4.4.2
4 | decord==0.4.0
5 | future==0.18.2
6 | imageio==2.8.0
7 | joblib==0.16.0
8 | kiwisolver==1.2.0
9 | matplotlib==3.2.1
10 | networkx==2.4
11 | numpy==1.18.5
12 | opencv-python==4.2.0.34
13 | pandas==1.0.4
14 | Pillow==7.1.2
15 | pyparsing==2.4.7
16 | python-dateutil==2.8.1
17 | pytz==2020.1
18 | PyWavelets==1.1.1
19 | scikit-image==0.17.2
20 | scikit-learn==0.23.1
21 | scipy==1.4.1
22 | six==1.15.0
23 | threadpoolctl==2.1.0
24 | tifffile==2020.6.3
25 | torch==1.6.0
26 | torchsummary==1.5.1
27 | torchvision==0.7.0
28 | tqdm==4.48.2
29 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('Agg')
4 | import matplotlib.pyplot as plt
5 | from sklearn.model_selection import train_test_split
6 |
7 | import torch
8 | import torch.optim as optim
9 | import torch.optim.lr_scheduler as lrs
10 | from torch.utils.data.sampler import SubsetRandomSampler
11 | import torchvision
12 | from torchvision import transforms as T
13 |
14 | from network.resnet import *
15 | from learning.lr_scheduler import GradualWarmupScheduler
16 |
17 | def get_model(args, shape, num_classes):
18 | model = eval(args.model)(
19 | shape,
20 | num_classes,
21 | checkpoint_dir=args.checkpoint_dir,
22 | checkpoint_name=args.checkpoint_name,
23 | pretrained=args.pretrained,
24 | pretrained_path=args.pretrained_path,
25 | norm=args.norm,
26 | )
27 | return model
28 |
29 |
30 | def make_optimizer(args, model):
31 | trainable = filter(lambda x: x.requires_grad, model.parameters())
32 |
33 | if args.optimizer == 'SGD':
34 | optimizer_function = optim.SGD
35 | kwargs = {'momentum': 0.9}
36 | elif args.optimizer == 'ADAM':
37 | optimizer_function = optim.Adam
38 | kwargs = {
39 | 'betas': (0.9, 0.999),
40 | 'eps': 1e-08
41 | }
42 | else:
43 | raise NameError('Unsupported optimizer: {}'.format(args.optimizer))
44 |
45 | kwargs['lr'] = args.learning_rate
46 | kwargs['weight_decay'] = args.weight_decay
47 |
48 | return optimizer_function(trainable, **kwargs)
49 |
50 |
51 | def make_scheduler(args, optimizer):
52 | if args.decay_type == 'step':
53 | scheduler = lrs.MultiStepLR(
54 | optimizer,
55 | milestones=[30, 60, 90],
56 | gamma=0.1
57 | )
58 | elif args.decay_type == 'step_warmup':
59 | scheduler = lrs.MultiStepLR(
60 | optimizer,
61 | milestones=[30, 60, 90],
62 | gamma=0.1
63 | )
64 | scheduler = GradualWarmupScheduler(
65 | optimizer,
66 | multiplier=1,
67 | total_epoch=5,
68 | after_scheduler=scheduler
69 | )
70 | elif args.decay_type == 'cosine_warmup':
71 | cosine_scheduler = lrs.CosineAnnealingLR(
72 | optimizer,
73 | T_max=args.epochs
74 | )
75 | scheduler = GradualWarmupScheduler(
76 | optimizer,
77 | multiplier=1,
78 | total_epoch=args.epochs//10,
79 | after_scheduler=cosine_scheduler
80 | )
81 | elif args.decay_type == 'swa':
82 | scheduler = lrs.CosineAnnealingLR(
83 | optimizer,
84 | T_max=args.epochs
85 | )
86 | else:
87 | raise Exception('unknown lr scheduler: {}'.format(args.decay_type))
88 |
89 | return scheduler
90 |
91 | def make_dataloader(args):
92 |
93 | train_trans = T.Compose([
94 | T.Resize((256, 256)),
95 | T.RandomHorizontalFlip(),
96 | T.ToTensor(),
97 | ])
98 |
99 | valid_trans = T.Compose([
100 | T.Resize((256, 256)),
101 | T.ToTensor(),
102 | ])
103 |
104 | test_trans = T.Compose([
105 | T.Resize((256, 256)),
106 | T.ToTensor(),
107 | ])
108 |
109 | trainset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=train_trans)
110 | validset = torchvision.datasets.ImageFolder(root="data/seg_train/seg_train", transform=valid_trans)
111 | testset = torchvision.datasets.ImageFolder(root="data/seg_test/seg_test", transform=test_trans)
112 |
113 | np.random.seed(args.seed)
114 | targets = trainset.targets
115 | train_idx, valid_idx = train_test_split(np.arange(len(targets)), test_size=0.2, shuffle=True, stratify=targets)
116 | train_sampler = SubsetRandomSampler(train_idx)
117 | valid_sampler = SubsetRandomSampler(valid_idx)
118 |
119 | train_loader = torch.utils.data.DataLoader(
120 | trainset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers
121 | )
122 |
123 | valid_loader = torch.utils.data.DataLoader(
124 | validset, batch_size=args.batch_size, sampler=valid_sampler, num_workers=args.num_workers
125 | )
126 |
127 | test_loader = torch.utils.data.DataLoader(
128 | testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
129 | )
130 |
131 | return train_loader, valid_loader, test_loader
132 |
133 |
134 | def plot_learning_curves(metrics, cur_epoch, args):
135 | x = np.arange(cur_epoch+1)
136 | fig, ax1 = plt.subplots()
137 | ax1.set_xlabel('epochs')
138 | ax1.set_ylabel('loss')
139 | ln1 = ax1.plot(x, metrics['train_loss'], color='tab:red')
140 | ln2 = ax1.plot(x, metrics['val_loss'], color='tab:red', linestyle='dashed')
141 | ax1.grid()
142 | ax2 = ax1.twinx()
143 | ax2.set_ylabel('accuracy')
144 | ln3 = ax2.plot(x, metrics['train_acc'], color='tab:blue')
145 | ln4 = ax2.plot(x, metrics['val_acc'], color='tab:blue', linestyle='dashed')
146 | lns = ln1+ln2+ln3+ln4
147 | plt.legend(lns, ['Train loss', 'Validation loss', 'Train accuracy','Validation accuracy'])
148 | plt.tight_layout()
149 | plt.savefig('{}/{}/learning_curve.png'.format(args.checkpoint_dir, args.checkpoint_name), bbox_inches='tight')
150 | plt.close('all')
151 |
152 | class AverageMeter(object):
153 | """Computes and stores the average and current value"""
154 | def __init__(self):
155 | self.reset()
156 |
157 | def reset(self):
158 | self.val = 0
159 | self.avg = 0
160 | self.sum = 0
161 | self.count = 0
162 | self.max = 0
163 | self.min = 1e5
164 |
165 | def update(self, val, n=1):
166 | self.val = val
167 | self.sum += val * n
168 | self.count += n
169 | self.avg = self.sum / self.count
170 | if val > self.max:
171 | self.max = val
172 | if val < self.min:
173 | self.min = val
174 |
175 | def accuracy(output, target, topk=(1,)):
176 | """Computes the precision@k for the specified values of k"""
177 | maxk = max(topk)
178 | batch_size = target.size(0)
179 |
180 | _, pred = output.topk(maxk, 1, True, True)
181 | pred = pred.t()
182 | correct = pred.eq(target.view(1, -1).expand_as(pred))
183 |
184 | res = []
185 | for k in topk:
186 | correct_k = correct[:k].view(-1).float().sum(0)
187 | res.append(correct_k.mul_(100.0 / batch_size))
188 | return res
--------------------------------------------------------------------------------