├── LICENSE
├── MODELS
    ├── bam.py
    ├── cbam.py
    └── model_resnet.py
├── README.md
├── scripts
    ├── train_imagenet_resnet50_bam.sh
    └── train_imagenet_resnet50_cbam.sh
└── train_imagenet.py


/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2019 Jongchan Park
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/MODELS/bam.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import math
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | 
 6 | class Flatten(nn.Module):
 7 |     def forward(self, x):
 8 |         return x.view(x.size(0), -1)
 9 | class ChannelGate(nn.Module):
10 |     def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
11 |         super(ChannelGate, self).__init__()
 12 |         # channel attention: an MLP applied to globally average-pooled features
13 |         self.gate_c = nn.Sequential()
14 |         self.gate_c.add_module( 'flatten', Flatten() )
15 |         gate_channels = [gate_channel]
16 |         gate_channels += [gate_channel // reduction_ratio] * num_layers
17 |         gate_channels += [gate_channel]
18 |         for i in range( len(gate_channels) - 2 ):
19 |             self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )
20 |             self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )
21 |             self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )
22 |         self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )
23 |     def forward(self, in_tensor):
 24 |         avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) )  # global average pooling (assumes square feature maps)
25 |         return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
26 | 
27 | class SpatialGate(nn.Module):
28 |     def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
29 |         super(SpatialGate, self).__init__()
30 |         self.gate_s = nn.Sequential()
31 |         self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))
 32 |         self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) )
 33 |         self.gate_s.add_module( 'gate_s_relu_reduce0', nn.ReLU() )
34 |         for i in range( dilation_conv_num ):
 35 |             self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3,
 36 |                                     padding=dilation_val, dilation=dilation_val) )
37 |             self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) )
38 |             self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() )
39 |         self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) )
40 |     def forward(self, in_tensor):
41 |         return self.gate_s( in_tensor ).expand_as(in_tensor)
42 | class BAM(nn.Module):
43 |     def __init__(self, gate_channel):
44 |         super(BAM, self).__init__()
45 |         self.channel_att = ChannelGate(gate_channel)
46 |         self.spatial_att = SpatialGate(gate_channel)
47 |     def forward(self,in_tensor):
 48 |         att = 1 + torch.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )
49 |         return att * in_tensor
50 | 
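A quick smoke test for the module above, as a hedged sketch: it assumes the repository root is on `PYTHONPATH`, and that feature maps are square, which `ChannelGate`'s pooling relies on. The tensor sizes are arbitrary illustrations.

```
import torch
from MODELS.bam import BAM

# BAM builds a combined channel/spatial attention map M and returns
# (1 + sigmoid(M)) * x, so the output shape always matches the input.
bam = BAM(gate_channel=64)
bam.eval()  # both gates contain BatchNorm layers

x = torch.randn(2, 64, 56, 56)   # (batch, channels, height, width)
out = bam(x)
print(out.shape)                 # torch.Size([2, 64, 56, 56])
```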


--------------------------------------------------------------------------------
/MODELS/cbam.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import math
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | 
 6 | class BasicConv(nn.Module):
 7 |     def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
 8 |         super(BasicConv, self).__init__()
 9 |         self.out_channels = out_planes
10 |         self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 |         self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 |         self.relu = nn.ReLU() if relu else None
13 | 
14 |     def forward(self, x):
15 |         x = self.conv(x)
16 |         if self.bn is not None:
17 |             x = self.bn(x)
18 |         if self.relu is not None:
19 |             x = self.relu(x)
20 |         return x
21 | 
22 | class Flatten(nn.Module):
23 |     def forward(self, x):
24 |         return x.view(x.size(0), -1)
25 | 
26 | class ChannelGate(nn.Module):
27 |     def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 |         super(ChannelGate, self).__init__()
29 |         self.gate_channels = gate_channels
30 |         self.mlp = nn.Sequential(
31 |             Flatten(),
32 |             nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 |             nn.ReLU(),
34 |             nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 |             )
36 |         self.pool_types = pool_types
37 |     def forward(self, x):
38 |         channel_att_sum = None
39 |         for pool_type in self.pool_types:
40 |             if pool_type=='avg':
41 |                 avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 |                 channel_att_raw = self.mlp( avg_pool )
43 |             elif pool_type=='max':
44 |                 max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 |                 channel_att_raw = self.mlp( max_pool )
46 |             elif pool_type=='lp':
47 |                 lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 |                 channel_att_raw = self.mlp( lp_pool )
49 |             elif pool_type=='lse':
 50 |                 # LSE: log-sum-exp pooling, a smooth approximation of max pooling
51 |                 lse_pool = logsumexp_2d(x)
52 |                 channel_att_raw = self.mlp( lse_pool )
53 | 
54 |             if channel_att_sum is None:
55 |                 channel_att_sum = channel_att_raw
56 |             else:
57 |                 channel_att_sum = channel_att_sum + channel_att_raw
58 | 
 59 |         scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60 |         return x * scale
61 | 
62 | def logsumexp_2d(tensor):
63 |     tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 |     s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 |     outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 |     return outputs
67 | 
68 | class ChannelPool(nn.Module):
69 |     def forward(self, x):
 70 |         return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )  # channel-wise max and mean stacked into a 2-channel map
71 | 
72 | class SpatialGate(nn.Module):
73 |     def __init__(self):
74 |         super(SpatialGate, self).__init__()
75 |         kernel_size = 7
76 |         self.compress = ChannelPool()
77 |         self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 |     def forward(self, x):
79 |         x_compress = self.compress(x)
80 |         x_out = self.spatial(x_compress)
 81 |         scale = torch.sigmoid(x_out) # broadcasting
82 |         return x * scale
83 | 
84 | class CBAM(nn.Module):
85 |     def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 |         super(CBAM, self).__init__()
87 |         self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 |         self.no_spatial=no_spatial
89 |         if not no_spatial:
90 |             self.SpatialGate = SpatialGate()
91 |     def forward(self, x):
92 |         x_out = self.ChannelGate(x)
93 |         if not self.no_spatial:
94 |             x_out = self.SpatialGate(x_out)
95 |         return x_out
96 | 
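A matching sketch for CBAM, under the same assumptions as the BAM example above:

```
import torch
from MODELS.cbam import CBAM

cbam = CBAM(gate_channels=256, reduction_ratio=16, pool_types=['avg', 'max'])
cbam.eval()

x = torch.randn(2, 256, 14, 14)
out = cbam(x)      # channel gate first, then the 7x7-conv spatial gate
print(out.shape)   # torch.Size([2, 256, 14, 14])

# no_spatial=True keeps only the channel-attention branch
cbam_channel_only = CBAM(gate_channels=256, no_spatial=True)
```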


--------------------------------------------------------------------------------
/MODELS/model_resnet.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | import torch.nn.functional as F
  4 | import math
  5 | from torch.nn import init
  6 | from .cbam import *
  7 | from .bam import *
  8 | 
  9 | def conv3x3(in_planes, out_planes, stride=1):
 10 |     "3x3 convolution with padding"
 11 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
 12 |                      padding=1, bias=False)
 13 | 
 14 | class BasicBlock(nn.Module):
 15 |     expansion = 1
 16 | 
 17 |     def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
 18 |         super(BasicBlock, self).__init__()
 19 |         self.conv1 = conv3x3(inplanes, planes, stride)
 20 |         self.bn1 = nn.BatchNorm2d(planes)
 21 |         self.relu = nn.ReLU(inplace=True)
 22 |         self.conv2 = conv3x3(planes, planes)
 23 |         self.bn2 = nn.BatchNorm2d(planes)
 24 |         self.downsample = downsample
 25 |         self.stride = stride
 26 | 
 27 |         if use_cbam:
 28 |             self.cbam = CBAM( planes, 16 )
 29 |         else:
 30 |             self.cbam = None
 31 | 
 32 |     def forward(self, x):
 33 |         residual = x
 34 | 
 35 |         out = self.conv1(x)
 36 |         out = self.bn1(out)
 37 |         out = self.relu(out)
 38 | 
 39 |         out = self.conv2(out)
 40 |         out = self.bn2(out)
 41 | 
 42 |         if self.downsample is not None:
 43 |             residual = self.downsample(x)
 44 | 
 45 |         if self.cbam is not None:
 46 |             out = self.cbam(out)
 47 | 
 48 |         out += residual
 49 |         out = self.relu(out)
 50 | 
 51 |         return out
 52 | 
 53 | class Bottleneck(nn.Module):
 54 |     expansion = 4
 55 | 
 56 |     def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
 57 |         super(Bottleneck, self).__init__()
 58 |         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
 59 |         self.bn1 = nn.BatchNorm2d(planes)
 60 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
 61 |                                padding=1, bias=False)
 62 |         self.bn2 = nn.BatchNorm2d(planes)
 63 |         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
 64 |         self.bn3 = nn.BatchNorm2d(planes * 4)
 65 |         self.relu = nn.ReLU(inplace=True)
 66 |         self.downsample = downsample
 67 |         self.stride = stride
 68 | 
 69 |         if use_cbam:
 70 |             self.cbam = CBAM( planes * 4, 16 )
 71 |         else:
 72 |             self.cbam = None
 73 | 
 74 |     def forward(self, x):
 75 |         residual = x
 76 | 
 77 |         out = self.conv1(x)
 78 |         out = self.bn1(out)
 79 |         out = self.relu(out)
 80 | 
 81 |         out = self.conv2(out)
 82 |         out = self.bn2(out)
 83 |         out = self.relu(out)
 84 | 
 85 |         out = self.conv3(out)
 86 |         out = self.bn3(out)
 87 | 
 88 |         if self.downsample is not None:
 89 |             residual = self.downsample(x)
 90 | 
 91 |         if self.cbam is not None:
 92 |             out = self.cbam(out)
 93 | 
 94 |         out += residual
 95 |         out = self.relu(out)
 96 | 
 97 |         return out
 98 | 
 99 | class ResNet(nn.Module):
 100 |     def __init__(self, block, layers, network_type, num_classes, att_type=None):
101 |         self.inplanes = 64
102 |         super(ResNet, self).__init__()
103 |         self.network_type = network_type
104 |         # different model config between ImageNet and CIFAR 
105 |         if network_type == "ImageNet":
106 |             self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
107 |             self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
108 |             self.avgpool = nn.AvgPool2d(7)
109 |         else:
110 |             self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
111 | 
112 |         self.bn1 = nn.BatchNorm2d(64)
113 |         self.relu = nn.ReLU(inplace=True)
114 | 
115 |         if att_type=='BAM':
116 |             self.bam1 = BAM(64*block.expansion)
117 |             self.bam2 = BAM(128*block.expansion)
118 |             self.bam3 = BAM(256*block.expansion)
119 |         else:
120 |             self.bam1, self.bam2, self.bam3 = None, None, None
121 | 
122 |         self.layer1 = self._make_layer(block, 64,  layers[0], att_type=att_type)
123 |         self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
124 |         self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
125 |         self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)
126 | 
127 |         self.fc = nn.Linear(512 * block.expansion, num_classes)
128 | 
 129 |         init.kaiming_normal_(self.fc.weight)
130 |         for key in self.state_dict():
131 |             if key.split('.')[-1]=="weight":
132 |                 if "conv" in key:
 133 |                     init.kaiming_normal_(self.state_dict()[key], mode='fan_out')
134 |                 if "bn" in key:
135 |                     if "SpatialGate" in key:
136 |                         self.state_dict()[key][...] = 0
137 |                     else:
138 |                         self.state_dict()[key][...] = 1
139 |             elif key.split(".")[-1]=='bias':
140 |                 self.state_dict()[key][...] = 0
141 | 
142 |     def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
143 |         downsample = None
144 |         if stride != 1 or self.inplanes != planes * block.expansion:
145 |             downsample = nn.Sequential(
146 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
147 |                           kernel_size=1, stride=stride, bias=False),
148 |                 nn.BatchNorm2d(planes * block.expansion),
149 |             )
150 | 
151 |         layers = []
152 |         layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM'))
153 |         self.inplanes = planes * block.expansion
154 |         for i in range(1, blocks):
155 |             layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM'))
156 | 
157 |         return nn.Sequential(*layers)
158 | 
159 |     def forward(self, x):
160 |         x = self.conv1(x)
161 |         x = self.bn1(x)
162 |         x = self.relu(x)
163 |         if self.network_type == "ImageNet":
164 |             x = self.maxpool(x)
165 | 
166 |         x = self.layer1(x)
 167 |         if self.bam1 is not None:
168 |             x = self.bam1(x)
169 | 
170 |         x = self.layer2(x)
 171 |         if self.bam2 is not None:
172 |             x = self.bam2(x)
173 | 
174 |         x = self.layer3(x)
 175 |         if self.bam3 is not None:
176 |             x = self.bam3(x)
177 | 
178 |         x = self.layer4(x)
179 | 
180 |         if self.network_type == "ImageNet":
181 |             x = self.avgpool(x)
182 |         else:
183 |             x = F.avg_pool2d(x, 4)
184 |         x = x.view(x.size(0), -1)
185 |         x = self.fc(x)
186 |         return x
187 | 
188 | def ResidualNet(network_type, depth, num_classes, att_type):
189 | 
190 |     assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
191 |     assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'
192 | 
193 |     if depth == 18:
194 |         model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)
195 | 
196 |     elif depth == 34:
197 |         model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)
198 | 
199 |     elif depth == 50:
200 |         model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)
201 | 
202 |     elif depth == 101:
203 |         model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)
204 | 
205 |     return model
206 | 
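A construction check for the `ResidualNet` factory above, sketched under the assumption of a CPU-only forward pass on one random 224x224 image:

```
import torch
from MODELS.model_resnet import ResidualNet

# ResNet-50 with a CBAM module inside every residual block
model = ResidualNet('ImageNet', depth=50, num_classes=1000, att_type='CBAM')
model.eval()

x = torch.randn(1, 3, 224, 224)
logits = model(x)
print(logits.shape)   # torch.Size([1, 1000])

# att_type='BAM' instead places one BAM module after each of the first
# three stages; att_type=None yields the plain ResNet-50 baseline.
```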


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # BAM and CBAM
 2 | Official PyTorch code for "[BAM: Bottleneck Attention Module (BMVC2018)](http://bmvc2018.org/contents/papers/0092.pdf)" and "[CBAM: Convolutional Block Attention Module (ECCV2018)](http://openaccess.thecvf.com/content_ECCV_2018/html/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.html)"
 3 | 
 4 | ### Updates & Notices
 5 | - 2018-10-08: ~~Currently, only CBAM test code is validated. **There may be minor errors in the training code**. Will be fixed in a few days.~~
 6 | - 2018-10-11: Training code validated. RESNET50+BAM pretrained weight added.
 7 | 
 8 | ### Requirement
 9 | 
10 | The code has been validated in the following environment:
11 | - Ubuntu 16.04, 4*GTX 1080 Ti, Docker (PyTorch 0.4.1, CUDA 9.0 + CuDNN 7.0, Python 3.6)
12 | 
13 | ### How to use
14 | 
15 | ResNet50-based examples are included. Example scripts can be found in the ```./scripts/``` directory.
16 | ImageNet data should be placed under ```./data/ImageNet/``` with folders named ```train``` and ```val```.
17 | 
18 | ```
19 | # To train with BAM (ResNet50 backbone)
20 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type BAM --prefix RESNET50_IMAGENET_BAM ./data/ImageNet
21 | # To train with CBAM (ResNet50 backbone)
22 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet
23 | ```
24 | 
25 | ### Resume with checkpoints
26 | 
27 | - ResNet50+CBAM (trained for 100 epochs) checkpoint is provided in this [link](https://drive.google.com/file/d/1mvAVvhLR_2XY_bPYxh-SEz4vDmGzSArO/view?usp=sharing). ACC@1=77.622 ACC@5=93.948
28 | - ResNet50+BAM (trained for 90 epochs) checkpoint is provided in this [link](https://drive.google.com/file/d/1auVf70gfL0ol40bvaX5rlbpn9cKIxhAL/view?usp=sharing). ACC@1=76.860 ACC@5=93.416
29 | 
30 | For validation, run the script as follows:
31 | ```
32 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --att-type CBAM --prefix EVAL --resume $CHECKPOINT_PATH$ --evaluate ./data/ImageNet
33 | ```
34 | 
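The checkpoints above were saved from a ```DataParallel```-wrapped model, so the ```state_dict``` keys carry a ```module.``` prefix. A minimal loading sketch for use outside ```train_imagenet.py``` (the checkpoint path is a placeholder):

```
import torch
from MODELS.model_resnet import ResidualNet

model = ResidualNet('ImageNet', 50, 1000, 'CBAM')
checkpoint = torch.load('path/to/checkpoint.pth.tar', map_location='cpu')
print(checkpoint['epoch'], checkpoint['best_prec1'])

# strip the 'module.' prefix added by DataParallel
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
model.eval()
```
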
35 | ### Other implementations
36 | 
37 | - [MXNet implementation of CBAM with several modifications](https://github.com/bruinxiong/Modified-CBAMnet.mxnet) by [bruinxiong](https://github.com/bruinxiong)
38 | 


--------------------------------------------------------------------------------
/scripts/train_imagenet_resnet50_bam.sh:
--------------------------------------------------------------------------------
 1 | python train_imagenet.py \
 2 | 			--ngpu 8 \
 3 | 			--workers 20 \
 4 | 			--arch resnet --depth 50 \
 5 | 			--epochs 100 \
 6 | 			--batch-size 256 --lr 0.1 \
 7 | 			--att-type BAM \
 8 | 			--prefix RESNET50_IMAGENET_BAM \
 9 | 			./data/ImageNet/
10 | 


--------------------------------------------------------------------------------
/scripts/train_imagenet_resnet50_cbam.sh:
--------------------------------------------------------------------------------
 1 | python train_imagenet.py \
 2 | 			--ngpu 8 \
 3 | 			--workers 20 \
 4 | 			--arch resnet --depth 50 \
 5 | 			--epochs 100 \
 6 | 			--batch-size 256 --lr 0.1 \
 7 | 			--att-type CBAM \
 8 | 			--prefix RESNET50_IMAGENET_CBAM \
 9 | 			./data/ImageNet/
10 | 


--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import os
  3 | import shutil
  4 | import time
  5 | import random
  6 | 
  7 | import torch
  8 | import torch.nn as nn
  9 | import torch.nn.parallel
 10 | import torch.backends.cudnn as cudnn
 11 | import torch.optim
 12 | import torch.utils.data
 13 | import torchvision.transforms as transforms
 14 | import torchvision.datasets as datasets
 15 | import torchvision.models as models
 16 | from MODELS.model_resnet import *
 17 | from PIL import ImageFile
 18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
 19 | model_names = sorted(name for name in models.__dict__
 20 |     if name.islower() and not name.startswith("__")
 21 |     and callable(models.__dict__[name]))
 22 | 
 23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 24 | parser.add_argument('data', metavar='DIR',
 25 |                     help='path to dataset')
 26 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',
 27 |                     help='model architecture: ' +
 28 |                         ' | '.join(model_names) +
 29 |                         ' (default: resnet)')
 30 | parser.add_argument('--depth', default=50, type=int, metavar='D',
 31 |                     help='model depth')
 32 | parser.add_argument('--ngpu', default=4, type=int, metavar='G',
 33 |                     help='number of gpus to use')
 34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
 35 |                     help='number of data loading workers (default: 4)')
 36 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
 37 |                     help='number of total epochs to run')
 38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
 39 |                     help='manual epoch number (useful on restarts)')
 40 | parser.add_argument('-b', '--batch-size', default=256, type=int,
 41 |                     metavar='N', help='mini-batch size (default: 256)')
 42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
 43 |                     metavar='LR', help='initial learning rate')
 44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
 45 |                     help='momentum')
 46 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
 47 |                     metavar='W', help='weight decay (default: 1e-4)')
 48 | parser.add_argument('--print-freq', '-p', default=10, type=int,
 49 |                     metavar='N', help='print frequency (default: 10)')
 50 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
 51 |                     help='path to latest checkpoint (default: none)')
 52 | parser.add_argument("--seed", type=int, default=1234, metavar='N', help='random seed (default: 1234)')
 53 | parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')
 54 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only')
 55 | parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM'], default=None)
 56 | best_prec1 = 0
 57 | 
 58 | if not os.path.exists('./checkpoints'):
 59 |     os.mkdir('./checkpoints')
 60 | 
 61 | def main():
 62 |     global args, best_prec1
 63 | 
 64 |     args = parser.parse_args()
 65 |     print ("args", args)
 66 | 
 67 |     torch.manual_seed(args.seed)
 68 |     torch.cuda.manual_seed_all(args.seed)
 69 |     random.seed(args.seed)
 70 | 
 71 |     # create model
 72 |     if args.arch == "resnet":
 73 |         model = ResidualNet( 'ImageNet', args.depth, 1000, args.att_type )
 74 | 
 75 |     # define loss function (criterion) and optimizer
 76 |     criterion = nn.CrossEntropyLoss().cuda()
 77 | 
 78 |     optimizer = torch.optim.SGD(model.parameters(), args.lr,
 79 |                             momentum=args.momentum,
 80 |                             weight_decay=args.weight_decay)
 81 |     model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
 82 |     #model = torch.nn.DataParallel(model).cuda()
 83 |     model = model.cuda()
 84 |     print ("model")
 85 |     print (model)
 86 | 
 87 |     # get the number of model parameters
 88 |     print('Number of model parameters: {}'.format(
 89 |         sum([p.data.nelement() for p in model.parameters()])))
 90 | 
 91 |     # optionally resume from a checkpoint
 92 |     if args.resume:
 93 |         if os.path.isfile(args.resume):
 94 |             print("=> loading checkpoint '{}'".format(args.resume))
 95 |             checkpoint = torch.load(args.resume)
 96 |             args.start_epoch = checkpoint['epoch']
 97 |             best_prec1 = checkpoint['best_prec1']
 98 |             model.load_state_dict(checkpoint['state_dict'])
 99 |             if 'optimizer' in checkpoint:
100 |                 optimizer.load_state_dict(checkpoint['optimizer'])
101 |             print("=> loaded checkpoint '{}' (epoch {})"
102 |                   .format(args.resume, checkpoint['epoch']))
103 |         else:
104 |             print("=> no checkpoint found at '{}'".format(args.resume))
105 | 
106 | 
107 |     cudnn.benchmark = True
108 | 
109 |     # Data loading code
110 |     traindir = os.path.join(args.data, 'train')
111 |     valdir = os.path.join(args.data, 'val')
112 |     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
113 |                                  std=[0.229, 0.224, 0.225])
114 | 
 115 |     # validation pipeline: resize the short side to 256, then center-crop to 224
 116 | 
 117 |     val_loader = torch.utils.data.DataLoader(
 118 |         datasets.ImageFolder(valdir, transforms.Compose([
 119 |             transforms.Resize(256),
 120 |             transforms.CenterCrop(224),
 121 |             transforms.ToTensor(),
 122 |             normalize,
 123 |         ])),
 124 |         batch_size=args.batch_size, shuffle=False,
 125 |         num_workers=args.workers, pin_memory=True)
126 |     if args.evaluate:
127 |         validate(val_loader, model, criterion, 0)
128 |         return
129 | 
130 |     train_dataset = datasets.ImageFolder(
131 |         traindir,
132 |         transforms.Compose([
 133 |             transforms.RandomResizedCrop(224),
134 |             transforms.RandomHorizontalFlip(),
135 |             transforms.ToTensor(),
136 |             normalize,
137 |         ]))
138 | 
139 |     train_sampler = None
140 | 
141 |     train_loader = torch.utils.data.DataLoader(
142 |         train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
143 |         num_workers=args.workers, pin_memory=True, sampler=train_sampler)
144 | 
145 |     for epoch in range(args.start_epoch, args.epochs):
146 |         adjust_learning_rate(optimizer, epoch)
147 |         
148 |         # train for one epoch
149 |         train(train_loader, model, criterion, optimizer, epoch)
150 |         
151 |         # evaluate on validation set
152 |         prec1 = validate(val_loader, model, criterion, epoch)
153 |         
154 |         # remember best prec@1 and save checkpoint
155 |         is_best = prec1 > best_prec1
156 |         best_prec1 = max(prec1, best_prec1)
157 |         save_checkpoint({
158 |             'epoch': epoch + 1,
159 |             'arch': args.arch,
160 |             'state_dict': model.state_dict(),
161 |             'best_prec1': best_prec1,
162 |             'optimizer' : optimizer.state_dict(),
163 |         }, is_best, args.prefix)
164 | 
165 | 
166 | def train(train_loader, model, criterion, optimizer, epoch):
167 |     batch_time = AverageMeter()
168 |     data_time = AverageMeter()
169 |     losses = AverageMeter()
170 |     top1 = AverageMeter()
171 |     top5 = AverageMeter()
172 | 
173 |     # switch to train mode
174 |     model.train()
175 | 
176 |     end = time.time()
177 |     for i, (input, target) in enumerate(train_loader):
178 |         # measure data loading time
179 |         data_time.update(time.time() - end)
180 |         
 181 |         target = target.cuda(non_blocking=True)
 182 |         input_var = input    # autograd.Variable wrappers are no-ops since PyTorch 0.4
 183 |         target_var = target
184 |         
185 |         # compute output
186 |         output = model(input_var)
187 |         loss = criterion(output, target_var)
188 |         
189 |         # measure accuracy and record loss
190 |         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
 191 |         losses.update(loss.item(), input.size(0))
192 |         top1.update(prec1[0], input.size(0))
193 |         top5.update(prec5[0], input.size(0))
194 |         
195 |         # compute gradient and do SGD step
196 |         optimizer.zero_grad()
197 |         loss.backward()
198 |         optimizer.step()
199 |         
200 |         # measure elapsed time
201 |         batch_time.update(time.time() - end)
202 |         end = time.time()
203 |         
204 |         if i % args.print_freq == 0:
205 |             print('Epoch: [{0}][{1}/{2}]\t'
206 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
207 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
208 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
209 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
210 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
211 |                    epoch, i, len(train_loader), batch_time=batch_time,
212 |                    data_time=data_time, loss=losses, top1=top1, top5=top5))
213 | 
214 | def validate(val_loader, model, criterion, epoch):
215 |     batch_time = AverageMeter()
216 |     losses = AverageMeter()
217 |     top1 = AverageMeter()
218 |     top5 = AverageMeter()
219 | 
220 |     # switch to evaluate mode
221 |     model.eval()
222 | 
223 |     end = time.time()
224 |     for i, (input, target) in enumerate(val_loader):
 225 |         target = target.cuda(non_blocking=True)
 226 | 
 227 |         # compute output (no gradients needed during evaluation)
 228 |         with torch.no_grad():
 229 |             output = model(input)
 230 |             loss = criterion(output, target)
 231 | 
232 |         
233 |         # measure accuracy and record loss
234 |         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
 235 |         losses.update(loss.item(), input.size(0))
236 |         top1.update(prec1[0], input.size(0))
237 |         top5.update(prec5[0], input.size(0))
238 |         
239 |         # measure elapsed time
240 |         batch_time.update(time.time() - end)
241 |         end = time.time()
242 |         
243 |         if i % args.print_freq == 0:
244 |             print('Test: [{0}/{1}]\t'
245 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
246 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
247 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
248 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
249 |                    i, len(val_loader), batch_time=batch_time, loss=losses,
250 |                    top1=top1, top5=top5))
251 |     
252 |     print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
253 |             .format(top1=top1, top5=top5))
254 | 
255 |     return top1.avg
256 | 
257 | 
258 | def save_checkpoint(state, is_best, prefix):
259 |     filename='./checkpoints/%s_checkpoint.pth.tar'%prefix
260 |     torch.save(state, filename)
261 |     if is_best:
262 |         shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar'%prefix)
263 | 
264 | 
265 | class AverageMeter(object):
266 |     """Computes and stores the average and current value"""
267 |     def __init__(self):
268 |         self.reset()
269 | 
270 |     def reset(self):
271 |         self.val = 0
272 |         self.avg = 0
273 |         self.sum = 0
274 |         self.count = 0
275 | 
276 |     def update(self, val, n=1):
277 |         self.val = val
278 |         self.sum += val * n
279 |         self.count += n
280 |         self.avg = self.sum / self.count
281 | 
282 | 
283 | def adjust_learning_rate(optimizer, epoch):
284 |     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
285 |     lr = args.lr * (0.1 ** (epoch // 30))
286 |     for param_group in optimizer.param_groups:
287 |         param_group['lr'] = lr
288 | 
289 | 
290 | def accuracy(output, target, topk=(1,)):
291 |     """Computes the precision@k for the specified values of k"""
292 |     maxk = max(topk)
293 |     batch_size = target.size(0)
294 | 
295 |     _, pred = output.topk(maxk, 1, True, True)
296 |     pred = pred.t()
297 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
298 | 
299 |     res = []
300 |     for k in topk:
301 |         correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
302 |         res.append(correct_k.mul_(100.0 / batch_size))
303 |     return res
304 | 
305 | 
306 | if __name__ == '__main__':
307 |     main()
308 | 
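A worked toy example of the `accuracy` helper above, as a sketch (it assumes the script is importable from the repository root; note the import also creates the `./checkpoints` directory as a side effect):

```
import torch
from train_imagenet import accuracy   # import also creates ./checkpoints/

output = torch.tensor([[0.1, 0.9, 0.0],    # top-1 prediction: class 1 (correct)
                       [0.8, 0.1, 0.1]])   # top-1 prediction: class 0 (label is 2)
target = torch.tensor([1, 2])

prec1, prec3 = accuracy(output, target, topk=(1, 3))
print(prec1)   # tensor([50.])  : 1 of 2 samples correct at top-1
print(prec3)   # tensor([100.]) : both labels fall within the top-3
```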


--------------------------------------------------------------------------------