├── requirements.txt ├── README ├── code.png ├── tse.png ├── vgg16_0.5_se_pre_attn_none.png ├── vgg16_0.5_tse_pre_attn_none.png └── vgg16_0.5_none_pre_attn_none.png ├── results └── results.txt ├── LICENSE ├── tse.py ├── README.md ├── utils.py └── example.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision~=0.6.1 2 | tqdm~=4.42.1 3 | fire~=0.3.1 4 | matplotlib~=3.1.3 -------------------------------------------------------------------------------- /README/code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuranusduke/Tiled-Squeeze-and-Excitation/HEAD/README/code.png -------------------------------------------------------------------------------- /README/tse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuranusduke/Tiled-Squeeze-and-Excitation/HEAD/README/tse.png -------------------------------------------------------------------------------- /README/vgg16_0.5_se_pre_attn_none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuranusduke/Tiled-Squeeze-and-Excitation/HEAD/README/vgg16_0.5_se_pre_attn_none.png -------------------------------------------------------------------------------- /README/vgg16_0.5_tse_pre_attn_none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuranusduke/Tiled-Squeeze-and-Excitation/HEAD/README/vgg16_0.5_tse_pre_attn_none.png -------------------------------------------------------------------------------- /README/vgg16_0.5_none_pre_attn_none.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuranusduke/Tiled-Squeeze-and-Excitation/HEAD/README/vgg16_0.5_none_pre_attn_none.png -------------------------------------------------------------------------------- /results/results.txt: -------------------------------------------------------------------------------- 1 | 2021-07-08 15:34:39.371222 ::: [attn_ratio : 0.5, attn_method : tse_pre_attn_none] --> acc : 92.16%. 2 | 2021-07-08 17:12:05.104624 ::: [attn_ratio : 0.5, attn_method : se_pre_attn_none] --> acc : 91.63%. 3 | 2021-07-08 18:05:23.101567 ::: [attn_ratio : 0.5, attn_method : none_pre_attn_none] --> acc : 90.73%. 4 | 2021-07-09 09:17:30.741822 ::: [attn_ratio : 0.5, attn_method : se_pre_attn_tse_to_se] --> acc : 80.31%. 5 | 2021-07-09 09:18:07.072458 ::: [attn_ratio : 0.5, attn_method : tse_pre_attn_se_to_tse] --> acc : 70.04%. 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 yuranus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tse.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file builds Tiled Squeeze-and-Excitation(TSE) from paper: 3 | --> https://arxiv.org/abs/2107.02145 4 | 5 | Created by Kunhong Yu 6 | Date: 2021/07/06 7 | """ 8 | import torch as t 9 | 10 | def weights_init(layer): 11 | """ 12 | weights initialization 13 | Args : 14 | --layer: one layer instance 15 | """ 16 | if isinstance(layer, t.nn.Linear) or isinstance(layer, t.nn.BatchNorm1d): 17 | t.nn.init.normal_(layer.weight, 0.0, 0.02) # we use 0.02 as initial value 18 | t.nn.init.constant_(layer.bias, 0.0) 19 | 20 | class TSE(t.nn.Module): 21 | """Define TSE operation""" 22 | """According to the paper, simple TSE can be implemented by 23 | several 1x1 conv followed by a average pooling with kernel size and stride, 24 | which is simple and effective to verify and to do parameter sharing 25 | In this implementation, column and row pooling kernel sizes are shared! 26 | """ 27 | 28 | def __init__(self, num_channels : int, attn_ratio : float, pool_kernel = 7): 29 | """ 30 | Args : 31 | --num_channels: # of input channels 32 | --attn_ratio: hidden size ratio 33 | --pool_kernel: pooling kernel size, default best is 7 according to paper 34 | """ 35 | super().__init__() 36 | 37 | self.num_channels = num_channels 38 | 39 | self.sigmoid = t.nn.Sigmoid() 40 | 41 | self.avg_pool = t.nn.AvgPool2d(kernel_size = pool_kernel, stride = pool_kernel, ceil_mode = True) 42 | 43 | self.tse = t.nn.Sequential( 44 | t.nn.Conv2d(self.num_channels, int(self.num_channels * attn_ratio), kernel_size = 1, stride = 1), 45 | t.nn.BatchNorm2d(int(self.num_channels * attn_ratio)), 46 | t.nn.ReLU(inplace = True), 47 | 48 | t.nn.Conv2d(int(self.num_channels * attn_ratio), self.num_channels, kernel_size = 1, stride = 1), 49 | t.nn.Sigmoid() 50 | ) 51 | self.kernel_size = pool_kernel 52 | 53 | def forward(self, x): 54 | """x has shape [m, C, H, W]""" 55 | _, C, H, W = x.size() 56 | # 1. TSE 57 | y = self.tse(self.avg_pool(x)) 58 | 59 | # 2. Re-calibrated 60 | y = t.repeat_interleave(y, self.kernel_size, dim = -2)[:, :, :H, :] 61 | y = t.repeat_interleave(y, self.kernel_size, dim = -1)[:, :, :, :W] 62 | 63 | return x * y 64 | 65 | # unit test 66 | if __name__ == '__main__': 67 | tse = TSE(1024, 0.5, 7) 68 | print(tse) 69 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ***Tiled Squeeze-and-Excite: Channel Attention With Local Spatial Context*** 2 | 3 | First of all, thank the authors very much for sharing this excellent paper ***Tiled Squeeze-and-Excite: Channel Attention With Local Spatial Context*** with us. 4 | Paper addr: https://arxiv.org/abs/2107.02145 5 | 6 | This repository contains unofficial implementation of TSE Attention and simple verification for modified VGG16 7 | with CIFAR10. If there are some bug problems in the 8 | implementation, please send me an email at yuranusduke@163.com or simply add issue. 9 | 10 | ## Backgrounds 11 | In this paper, authors believe that in original SE net, Global Average Pooling(GAP) is not necessary, and 12 | most importantly, in the backprop, AI hardware will store whole feature map according to GAP, which gives a bottleneck 13 | in time efficiency. Therefore, instead of using GAP, they propose to use average pooling with kernel 14 | and stride, and use `Conv2d` with 1x1 kernel size in the following nonlinear transformation. 15 | They propose to use larger kernel size(e.g. 7) in the deeper layer. But in this repo, we experiment with 16 | small images(CIFAR10) instead of ImageNet, so we use kernel size 2. 17 | 18 | ![img](./README/tse.png) 19 | 20 | And official pseudo code is, 21 | 22 | ![img](./README/code.png) 23 | 24 | One can read paper carefully to understand how and why they design architecture like this. 25 | 26 | ## Requirements 27 | 28 | ```Python 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | ## Implementation 33 | 34 | We simply run CIFAR10 with modified VGG16. 35 | 36 | ### Hyper-parameters and defaults 37 | ```bash 38 | --device = 'cuda' # learning device 39 | --attn_ratio = 0.5 # hidden size ratio 40 | --attn_method = 'se' # 'se' or 'none' or 'tse' 41 | --pre_attn = 'se_to_tse' # 'none' or 'se_to_tse' or 'tse_to_se' 42 | --epochs=80 # training epochs 43 | --batch_size=64 # batch size 44 | --init_lr=0.1 # initial learning rate 45 | --gamma=0.2 # learning rate decay 46 | --milestones=[40,60,80] # learning rate decay milestones 47 | --weight_decay=9e-6 # weight decay 48 | ``` 49 | 50 | ### Train & Test 51 | 52 | ```python 53 | python example.py main \ 54 | --device='cuda' \ 55 | --attn_ratio=0.5 \ 56 | --attn_method='se' \ 57 | --pre_attn='none' \ 58 | --epochs=100 \ 59 | --batch_size=20 \ 60 | --init_lr=0.1 \ 61 | --gamma=0.2 \ 62 | --milestones=[20,40,60,80] \ 63 | --weight_decay=1e-5 64 | 65 | ``` 66 | 67 | ## Results 68 | 69 | ### TSE & SE with learning from scratch 70 | | Model | Acc. | 71 | | ----------------- | ----------- | 72 | | baseline | 90.73% | 73 | | SE | 91.63% | 74 | | TSE | **92.16%** | 75 | 76 | ### TSE & SE with pretrained weights 77 | 78 | According to paper, we use pretrained SE net to inject SE modules directly into TSE net in 79 | test, and vice versa. 80 | From this table, we observe, though we did not get ideal results, attention modules transferred 81 | from TSE to SE are much more efficient than those transferred from SE to TSE. 82 | 83 | | Method | Acc. | 84 | | ----------------- | ----------- | 85 | | SE->TSE | 70.04% | 86 | | TSE->SE | **80.31%** | 87 | 88 | 89 | ## Training statistics 90 | 91 | ### Baseline 92 | ![img](README/vgg16_0.5_none_pre_attn_none.png) 93 | 94 | ### TSE with learning from scratch 95 | ![img](README/vgg16_0.5_tse_pre_attn_none.png) 96 | 97 | ### SE with learning from scratch 98 | ![img](README/vgg16_0.5_se_pre_attn_none.png) 99 | 100 | 101 | ***
Veni,vidi,vici --Caesar
*** 102 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Define some utilities 3 | 4 | Created by Kunhong Yu 5 | Date: 2021/07/06 6 | """ 7 | import torch as t 8 | from torch.nn import functional as F 9 | from tse import TSE 10 | 11 | #################################### 12 | # Define utilities # 13 | #################################### 14 | def _conv_layer(in_channels : int, 15 | out_channels : int) -> t.nn.Sequential: 16 | """Define conv layer 17 | Args : 18 | --in_channels: input channels 19 | --out_channels: output channels 20 | return : 21 | --conv layer 22 | """ 23 | conv_layer = t.nn.Sequential( 24 | t.nn.Conv2d(in_channels = in_channels, out_channels = out_channels, 25 | kernel_size = 3, stride = 1, padding = 1), 26 | t.nn.BatchNorm2d(out_channels), 27 | t.nn.ReLU(inplace = True) 28 | ) 29 | 30 | return conv_layer 31 | 32 | 33 | def vgg_block(in_channels : int, 34 | out_channels : int, 35 | repeat : int) -> list: 36 | """Define VGG block 37 | Args : 38 | --in_channels: input channels 39 | --out_channels: output channels 40 | --repeat 41 | return : 42 | --block 43 | """ 44 | block = [ 45 | _conv_layer(in_channels = in_channels if i == 0 else out_channels, 46 | out_channels = out_channels) 47 | for i in range(repeat) 48 | ] 49 | 50 | return block 51 | 52 | #################################### 53 | # Define SE # 54 | #################################### 55 | class SEAttention(t.nn.Module): 56 | """Define SE operation""" 57 | 58 | def __init__(self, num_channels : int, attn_ratio : float): 59 | """ 60 | Args : 61 | --num_channels: # of input channels 62 | --attn_ratio: hidden size ratio 63 | """ 64 | super(SEAttention, self).__init__() 65 | 66 | self.num_channels = num_channels 67 | self.hidden_size = int(attn_ratio * self.num_channels) 68 | 69 | # 1. Trunk, we use T(x) = x 70 | # 2. SE attention 71 | self.SE = t.nn.Sequential( 72 | t.nn.Linear(self.num_channels, self.hidden_size), 73 | t.nn.BatchNorm1d(self.hidden_size), 74 | t.nn.ReLU(inplace = True), 75 | 76 | t.nn.Linear(self.hidden_size, self.num_channels), 77 | t.nn.BatchNorm1d(self.num_channels), 78 | t.nn.Sigmoid() 79 | ) 80 | 81 | def forward(self, x): 82 | # 1. T(x) 83 | Tx = x 84 | # 2. SE attention 85 | x = F.adaptive_avg_pool2d(x, (1, 1)) # global average pooling 86 | x = x.squeeze() 87 | Ax = self.SE(x) 88 | 89 | # 3. output 90 | x = Tx * t.unsqueeze(t.unsqueeze(Ax, dim = -1), dim = -1) # broadcasting 91 | 92 | return x 93 | 94 | def get_attention(channels : int, attn_ratio = 0.5, pool_kernel = 7, method = 'se'): 95 | """Get attention method 96 | Args : 97 | --channels: number of input channels 98 | --attn_ratio: attention ratio, default is 0.5 99 | --pool_kernel: 7 as default according to paper 100 | --method: 'se' or 'tse', default is 'se' 101 | return : 102 | --attn: attention method 103 | """ 104 | if method == 'se': 105 | attn = SEAttention(num_channels = channels, 106 | attn_ratio = attn_ratio) 107 | elif method == 'tse': 108 | attn = TSE(num_channels = channels, attn_ratio = attn_ratio, pool_kernel = pool_kernel) 109 | 110 | else: 111 | raise Exception('No other attentions!') 112 | 113 | return attn 114 | 115 | #################################### 116 | # Define VGG16 # 117 | #################################### 118 | class VGG16(t.nn.Module): 119 | """Define VGG16-style model""" 120 | 121 | def __init__(self, attn_method = 'none', attn_ratio = 0.5): 122 | """ 123 | Args : 124 | --attn_method: 'none'/'se'/'tse' 125 | --attn_ratio: hidden size ratio, default is 0.5 126 | --pre_attn: None or [att1, attn2, ...] 127 | """ 128 | super(VGG16, self).__init__() 129 | 130 | self.attn_method = attn_method 131 | 132 | self.layer1 = t.nn.Sequential(*vgg_block(in_channels = 3, 133 | out_channels = 64, 134 | repeat = 2)) 135 | 136 | if self.attn_method != 'none': 137 | self.attn1 = get_attention(channels = 64, attn_ratio = attn_ratio, pool_kernel = 2, method = self.attn_method) 138 | 139 | self.layer2 = t.nn.Sequential(*vgg_block(in_channels = 64, 140 | out_channels = 128, 141 | repeat = 2)) 142 | 143 | if self.attn_method != 'none': 144 | self.attn2 = get_attention(channels = 128, attn_ratio = attn_ratio, pool_kernel = 2, method = self.attn_method) 145 | 146 | self.layer3 = t.nn.Sequential(*vgg_block(in_channels = 128, 147 | out_channels = 256, 148 | repeat = 3)) 149 | 150 | if self.attn_method != 'none': 151 | self.attn3 = get_attention(channels = 256, attn_ratio = attn_ratio, pool_kernel = 2, method = self.attn_method) 152 | 153 | self.layer4 = t.nn.Sequential(*vgg_block(in_channels = 256, 154 | out_channels = 512, 155 | repeat = 3)) 156 | 157 | if self.attn_method != 'none': 158 | self.attn4 = get_attention(channels = 512, attn_ratio = attn_ratio, pool_kernel = 2, method = self.attn_method) 159 | 160 | self.fc = t.nn.Sequential( # unlike original VGG16, I reduce some fc 161 | # parameters to fit my 2070 device 162 | t.nn.Linear(512, 256), 163 | t.nn.ReLU(inplace = True), 164 | t.nn.Linear(256, 10) 165 | ) 166 | 167 | self.max_pool = t.nn.MaxPool2d(kernel_size = 2, stride = 2) 168 | 169 | def forward(self, x, pre_attn = None): 170 | """pre_attn is None or [attn1, attn2, ...]""" 171 | assert pre_attn is None or len(pre_attn) == 4 172 | 173 | x1 = self.layer1(x) 174 | if self.attn_method != 'none': 175 | if pre_attn is None: 176 | x1 = self.attn1(x1) 177 | else: 178 | x1 = pre_attn[0](x1) 179 | x1 = self.max_pool(x1) 180 | 181 | x2 = self.layer2(x1) 182 | if self.attn_method != 'none': 183 | if pre_attn is None: 184 | x2 = self.attn2(x2) 185 | else: 186 | x2 = pre_attn[1](x2) 187 | x2 = self.max_pool(x2) 188 | 189 | x3 = self.layer3(x2) 190 | if self.attn_method != 'none': 191 | if pre_attn is None: 192 | x3 = self.attn3(x3) 193 | else: 194 | x3 = pre_attn[2](x3) 195 | x3 = self.max_pool(x3) 196 | 197 | x4 = self.layer4(x3) 198 | if self.attn_method != 'none': 199 | if pre_attn is None: 200 | x4 = self.attn4(x4) 201 | else: 202 | x4 = pre_attn[3](x4) 203 | x4 = self.max_pool(x4) 204 | 205 | x = F.adaptive_avg_pool2d(x4, (1, 1)) 206 | x = x.squeeze() 207 | x = self.fc(x) 208 | 209 | return x -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | One Cifar10 example of the results 4 | I use modified VGG16(less parameters in fc and with bn) and CIFAR-10 5 | But I use TSE and SE at the end of each block unlike in original paper 6 | 7 | Created by Kunhong Yu 8 | Date: 2021/07/06 9 | """ 10 | import torch as t 11 | import torchvision as tv 12 | from tse import weights_init 13 | import tqdm 14 | import fire 15 | import matplotlib.pyplot as plt 16 | from utils import VGG16 17 | import datetime 18 | 19 | ############Hyper-parameters############ 20 | device = 'cuda' # learning device 21 | attn_ratio = 0.5 # hidden size ratio 22 | attn_method = 'se' 23 | pre_attn = 'se->tse' # 'none' or 'se->tse' or 'tse->se' 24 | 25 | epochs = 100 # training epochs 26 | batch_size = 32 # batch size 27 | init_lr = 0.1 # initial learnign rate 28 | gamma = 0.2 # learning rate decay, here we use step decay strategy for simplicity 29 | milestones = [20, 40, 60, 80] # learning rate decay epochs 30 | weight_decay = 1e-5 # weight decay 31 | ######################################## 32 | 33 | # Step 0 Decide the structure of the model# 34 | # Step 1 Load the data set# 35 | def main(**kwargs): 36 | """Simply run in one cell""" 37 | 38 | ###################### 39 | # Unfold parameters # 40 | ###################### 41 | device = kwargs['device'] 42 | attn_ratio = kwargs['attn_ratio'] 43 | attn_method = kwargs['attn_method'] 44 | pre_attn = kwargs['pre_attn'] 45 | 46 | epochs = kwargs['epochs'] 47 | batch_size = kwargs['batch_size'] 48 | init_lr = kwargs['init_lr'] 49 | gamma = kwargs['gamma'] 50 | milestones = kwargs['milestones'] 51 | weight_decay = kwargs['weight_decay'] 52 | 53 | if 'only_test' in kwargs: 54 | only_test = kwargs['only_test'] 55 | 56 | device = t.device(device) 57 | 58 | transform = \ 59 | tv.transforms.Compose([ 60 | tv.transforms.RandomCrop(32, padding = 4), 61 | tv.transforms.RandomHorizontalFlip(), 62 | tv.transforms.RandomRotation(15), 63 | tv.transforms.ToTensor(), 64 | tv.transforms.Normalize(0.5, 0.5)]) 65 | 66 | train_data = tv.datasets.CIFAR10(root = './data', 67 | download = True, 68 | train = True, 69 | transform = transform) 70 | test_data = tv.datasets.CIFAR10(root = './data', 71 | download = True, 72 | train = False, 73 | transform = tv.transforms.Compose([ 74 | tv.transforms.ToTensor(), 75 | tv.transforms.Normalize(0.5, 0.5) 76 | ])) 77 | 78 | train_loader = t.utils.data.DataLoader(train_data, 79 | shuffle = True, 80 | batch_size = batch_size) 81 | test_loader = t.utils.data.DataLoader(test_data, 82 | shuffle = False, 83 | batch_size = batch_size) 84 | 85 | # Step 2 Reshape the inputs# 86 | # Step 3 Normalize the inputs# 87 | # Step 4 Initialize parameters# 88 | # Step 5 Forward propagation(Vectorization/Activation functions)# 89 | 90 | global model 91 | model = VGG16(attn_method = attn_method, attn_ratio = attn_ratio) 92 | model.apply(weights_init) 93 | model.to(device) 94 | print('VGG model : \n', model) 95 | 96 | # Step 6 Compute cost# 97 | loss = t.nn.CrossEntropyLoss().to(device) 98 | # Step 7 Backward propagation(Vectorization/Activation functions gradients)# 99 | optimizer = t.optim.SGD(filter(lambda x : x.requires_grad, model.parameters()), 100 | lr = init_lr, momentum = 0.9, weight_decay = weight_decay, nesterov = True) 101 | 102 | lr_scheduler = t.optim.lr_scheduler.MultiStepLR(optimizer, 103 | gamma = gamma, 104 | milestones = milestones) 105 | 106 | def eval(model, eval_iter, device, loss = None, pre_attn = 'none'): 107 | """This function is used to evaluate model 108 | Args : 109 | --model: model instance 110 | --eval_iter: evaluation data iter 111 | --device 112 | --loss: default is None 113 | --pre_attn: None or [attn1, attn2, ...] 114 | return : 115 | --eval_loss: eval loss 116 | --eval_acc: eval acc 117 | """ 118 | model.eval() 119 | if pre_attn == 'none': 120 | pre_attn = None 121 | 122 | else: 123 | if pre_attn == 'se_to_tse': 124 | pre_model = t.load('./results/vgg16_' + str(attn_ratio) + '_se' + '_pre_attn_none' + '.pth') 125 | elif pre_attn == 'tse_to_se': 126 | pre_model = t.load('./results/vgg16_' + str(attn_ratio) + '_tse' + '_pre_attn_none' + '.pth') 127 | pre_model.to(device) 128 | pre_attn = [pre_model.attn1, pre_model.attn2, pre_model.attn3, pre_model.attn4] 129 | 130 | with t.no_grad(): 131 | count = 0. 132 | eval_loss = 0. 133 | eval_acc = 0. 134 | for i, (batch_x, batch_y) in enumerate(eval_iter): 135 | batch_x = batch_x.view(batch_x.size(0), 3, 32, 32) 136 | batch_x = batch_x.to(device) 137 | batch_y = batch_y.to(device) 138 | 139 | out = model(batch_x, pre_attn = pre_attn) 140 | preds = t.argmax(out, dim = 1) 141 | if loss is not None: 142 | batch_loss = loss(out, batch_y) 143 | eval_loss += batch_loss.item() 144 | 145 | correct = t.sum(batch_y == preds).float() 146 | batch_acc = correct / batch_x.size(0) 147 | eval_acc += batch_acc.item() 148 | 149 | count += 1. 150 | 151 | eval_acc /= count 152 | if loss is not None: 153 | eval_loss /= count 154 | 155 | return eval_loss, eval_acc 156 | 157 | return eval_acc 158 | 159 | # Step 8 Update parameters# 160 | if not only_test: 161 | train_losses = [] 162 | eval_losses = [] 163 | train_accs = [] 164 | eval_accs = [] 165 | print('\nStart training...') 166 | for epoch in tqdm.tqdm(range(epochs)): 167 | print('Epoch %d / %d.' % (epoch + 1, epochs)) 168 | print('Current learning rate : ', optimizer.state_dict()['param_groups'][0]['lr']) 169 | epoch_loss = 0. 170 | epoch_acc = 0. 171 | count = 0. 172 | for i, (batch_x, batch_y) in enumerate(train_loader): 173 | batch_x = batch_x.to(device) 174 | batch_y = batch_y.to(device) 175 | 176 | optimizer.zero_grad() 177 | out = model(batch_x) 178 | batch_loss = loss(out, batch_y) 179 | batch_loss.backward() 180 | optimizer.step() 181 | 182 | if i % batch_size == 0: 183 | count += 1. 184 | preds = t.argmax(out, dim = 1) 185 | correct = t.sum(batch_y == preds).float() 186 | acc = correct / batch_x.size(0) 187 | print('\t\033[4;33m Training Batch INFO :\033[0m Batch %d has loss : %.3f --> acc : %.2f%%.' % ( 188 | i + 1, batch_loss.item(), acc.item() * 100. 189 | )) 190 | 191 | epoch_acc += acc.item() 192 | epoch_loss += batch_loss.item() 193 | 194 | epoch_loss /= count 195 | epoch_acc /= count 196 | train_losses.append(epoch_loss) 197 | train_accs.append(epoch_acc) 198 | eval_loss, eval_acc = eval(model, 199 | eval_iter = test_loader, 200 | device = device, loss = loss) 201 | model.train() 202 | eval_losses.append(eval_loss) 203 | eval_accs.append(eval_acc) 204 | print('\033[31m Training Epoch INFO :\033[0m This epoch has train loss : %.3f --> train acc : %.2f%% || ' 205 | 'eval loss : %.3f --> eval acc : %.2f%%.' % ( 206 | epoch_loss, epoch_acc * 100., eval_loss, eval_acc * 100. 207 | )) 208 | 209 | lr_scheduler.step() 210 | 211 | 212 | print('Training is done!\n') 213 | t.save(model, './results/vgg16_' + str(attn_ratio) + '_' + attn_method + '_pre_attn_' + pre_attn + '.pth') 214 | 215 | # visualize 216 | f, ax = plt.subplots(1, 2, figsize = (20, 6)) 217 | f.suptitle('Training and eval statistics') 218 | 219 | ax[0].plot(range(len(train_losses)), train_losses, label = 'training_loss') 220 | ax[0].plot(range(len(eval_losses)), eval_losses, label = 'eval_loss') 221 | ax[0].set_xlabel('Steps') 222 | ax[0].set_ylabel('Value') 223 | ax[0].set_title('Losses') 224 | ax[0].grid(True) 225 | ax[0].legend(loc = 'best') 226 | 227 | ax[1].plot(range(len(train_accs)), train_accs, label = 'training_acc') 228 | ax[1].plot(range(len(eval_accs)), eval_accs, label = 'eval_acc') 229 | ax[1].set_xlabel('Steps') 230 | ax[1].set_ylabel('Value') 231 | ax[1].set_title('Accs') 232 | ax[1].grid(True) 233 | ax[1].legend(loc = 'best') 234 | 235 | plt.savefig('./results/vgg16_' + str(attn_ratio) + '_' + attn_method + '_pre_attn_' + pre_attn + '.png') 236 | plt.close() 237 | 238 | # Step 9 Make a test# 239 | if only_test: 240 | model = t.load('./results/vgg16_' + str(attn_ratio) + '_' + attn_method + '_pre_attn_none' + '.pth') 241 | model.to(device) 242 | 243 | final_test_acc = eval(model, test_loader, device, pre_attn = pre_attn) 244 | print('Final test acc : {:.2f}%.'.format(final_test_acc * 100.)) 245 | with open('./results/results.txt', 'a+') as f: 246 | model_string = str(datetime.datetime.now()) + ' ::: ' 247 | model_string += "[attn_ratio : " + str(attn_ratio) + ", attn_method : " + attn_method + "_pre_attn_" + pre_attn + "] --> acc : %.2f%%." % (final_test_acc * 100.) + '\n' 248 | f.write(model_string) 249 | 250 | 251 | if __name__ == '__main__': 252 | fire.Fire() 253 | 254 | """ 255 | Usage: 256 | python example.py main --device='cuda' --attn_ratio=0.5 --attn_method='se' --pre_attn='none' --epochs=100 --batch_size=20 --init_lr=0.1 --gamma=0.2 --milestones=[20,40,60,80] --weight_decay=1e-5 257 | """ 258 | 259 | 260 | print('\nDone!\n') --------------------------------------------------------------------------------