├── LICENSE
├── MODELS
│   ├── bam.py
│   ├── cbam.py
│   └── model_resnet.py
├── README.md
├── scripts
│   ├── train_imagenet_resnet50_bam.sh
│   └── train_imagenet_resnet50_cbam.sh
└── train_imagenet.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Jongchan Park
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MODELS/bam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class Flatten(nn.Module):
7 | def forward(self, x):
8 | return x.view(x.size(0), -1)
9 | class ChannelGate(nn.Module):
10 | def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
11 | super(ChannelGate, self).__init__()
12 | # gating MLP: flatten -> (fc -> bn -> relu) x num_layers -> final fc
13 | self.gate_c = nn.Sequential()
14 | self.gate_c.add_module( 'flatten', Flatten() )
15 | gate_channels = [gate_channel]
16 | gate_channels += [gate_channel // reduction_ratio] * num_layers
17 | gate_channels += [gate_channel]
18 | for i in range( len(gate_channels) - 2 ):
19 | self.gate_c.add_module( 'gate_c_fc_%d'%i, nn.Linear(gate_channels[i], gate_channels[i+1]) )
20 | self.gate_c.add_module( 'gate_c_bn_%d'%(i+1), nn.BatchNorm1d(gate_channels[i+1]) )
21 | self.gate_c.add_module( 'gate_c_relu_%d'%(i+1), nn.ReLU() )
22 | self.gate_c.add_module( 'gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]) )
23 | def forward(self, in_tensor):
24 | avg_pool = F.avg_pool2d( in_tensor, in_tensor.size(2), stride=in_tensor.size(2) )
25 | return self.gate_c( avg_pool ).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)
26 |
27 | class SpatialGate(nn.Module):
28 | def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
29 | super(SpatialGate, self).__init__()
30 | self.gate_s = nn.Sequential()
31 | self.gate_s.add_module( 'gate_s_conv_reduce0', nn.Conv2d(gate_channel, gate_channel//reduction_ratio, kernel_size=1))
32 | self.gate_s.add_module( 'gate_s_bn_reduce0', nn.BatchNorm2d(gate_channel//reduction_ratio) )
33 | self.gate_s.add_module( 'gate_s_relu_reduce0',nn.ReLU() )
34 | for i in range( dilation_conv_num ):
35 | self.gate_s.add_module( 'gate_s_conv_di_%d'%i, nn.Conv2d(gate_channel//reduction_ratio, gate_channel//reduction_ratio, kernel_size=3, \
36 | padding=dilation_val, dilation=dilation_val) )
37 | self.gate_s.add_module( 'gate_s_bn_di_%d'%i, nn.BatchNorm2d(gate_channel//reduction_ratio) )
38 | self.gate_s.add_module( 'gate_s_relu_di_%d'%i, nn.ReLU() )
39 | self.gate_s.add_module( 'gate_s_conv_final', nn.Conv2d(gate_channel//reduction_ratio, 1, kernel_size=1) )
40 | def forward(self, in_tensor):
41 | return self.gate_s( in_tensor ).expand_as(in_tensor)
42 | class BAM(nn.Module):
43 | def __init__(self, gate_channel):
44 | super(BAM, self).__init__()
45 | self.channel_att = ChannelGate(gate_channel)
46 | self.spatial_att = SpatialGate(gate_channel)
47 | def forward(self,in_tensor):
48 | att = 1 + torch.sigmoid( self.channel_att(in_tensor) * self.spatial_att(in_tensor) )
49 | return att * in_tensor
50 |
--------------------------------------------------------------------------------
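
A minimal usage sketch for the BAM module above (hedged: the import path assumes the repository root is on `PYTHONPATH`, as in `train_imagenet.py`, and the tensor shape is illustrative):

```python
# Sketch: apply BAM to a dummy feature map and check that shapes are preserved.
import torch
from MODELS.bam import BAM

bam = BAM(gate_channel=64)        # attention over a 64-channel feature map
bam.eval()                        # eval mode: no BatchNorm statistics updates on dummy data
x = torch.randn(2, 64, 56, 56)    # (batch, channels, height, width)
out = bam(x)                      # out = (1 + sigmoid(channel_att * spatial_att)) * x
assert out.shape == x.shape
```
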
/MODELS/cbam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
33 | nn.ReLU(),
34 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE: log-sum-exp pooling
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
60 | return x * scale
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 | scale = torch.sigmoid(x_out) # broadcasting
82 | return x * scale
83 |
84 | class CBAM(nn.Module):
85 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
86 | super(CBAM, self).__init__()
87 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
88 | self.no_spatial=no_spatial
89 | if not no_spatial:
90 | self.SpatialGate = SpatialGate()
91 | def forward(self, x):
92 | x_out = self.ChannelGate(x)
93 | if not self.no_spatial:
94 | x_out = self.SpatialGate(x_out)
95 | return x_out
96 |
--------------------------------------------------------------------------------
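
As with `bam.py`, a minimal sketch (same import-path assumption) showing that CBAM applies the channel gate first and the spatial gate second:

```python
# Sketch: apply CBAM to a dummy feature map.
import torch
from MODELS.cbam import CBAM

cbam = CBAM(gate_channels=256, reduction_ratio=16, pool_types=['avg', 'max'])
cbam.eval()
x = torch.randn(2, 256, 14, 14)
out = cbam(x)                     # channel attention, then spatial attention
assert out.shape == x.shape
```
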
/MODELS/model_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch.nn import init
6 | from .cbam import *
7 | from .bam import *
8 |
9 | def conv3x3(in_planes, out_planes, stride=1):
10 | "3x3 convolution with padding"
11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
12 | padding=1, bias=False)
13 |
14 | class BasicBlock(nn.Module):
15 | expansion = 1
16 |
17 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
18 | super(BasicBlock, self).__init__()
19 | self.conv1 = conv3x3(inplanes, planes, stride)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu = nn.ReLU(inplace=True)
22 | self.conv2 = conv3x3(planes, planes)
23 | self.bn2 = nn.BatchNorm2d(planes)
24 | self.downsample = downsample
25 | self.stride = stride
26 |
27 | if use_cbam:
28 | self.cbam = CBAM( planes, 16 )
29 | else:
30 | self.cbam = None
31 |
32 | def forward(self, x):
33 | residual = x
34 |
35 | out = self.conv1(x)
36 | out = self.bn1(out)
37 | out = self.relu(out)
38 |
39 | out = self.conv2(out)
40 | out = self.bn2(out)
41 |
42 | if self.downsample is not None:
43 | residual = self.downsample(x)
44 |
45 | if self.cbam is not None:
46 | out = self.cbam(out)
47 |
48 | out += residual
49 | out = self.relu(out)
50 |
51 | return out
52 |
53 | class Bottleneck(nn.Module):
54 | expansion = 4
55 |
56 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_cbam=False):
57 | super(Bottleneck, self).__init__()
58 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
59 | self.bn1 = nn.BatchNorm2d(planes)
60 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
61 | padding=1, bias=False)
62 | self.bn2 = nn.BatchNorm2d(planes)
63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
64 | self.bn3 = nn.BatchNorm2d(planes * 4)
65 | self.relu = nn.ReLU(inplace=True)
66 | self.downsample = downsample
67 | self.stride = stride
68 |
69 | if use_cbam:
70 | self.cbam = CBAM( planes * 4, 16 )
71 | else:
72 | self.cbam = None
73 |
74 | def forward(self, x):
75 | residual = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | residual = self.downsample(x)
90 |
91 | if self.cbam is not None:
92 | out = self.cbam(out)
93 |
94 | out += residual
95 | out = self.relu(out)
96 |
97 | return out
98 |
99 | class ResNet(nn.Module):
100 | def __init__(self, block, layers, network_type, num_classes, att_type=None):
101 | self.inplanes = 64
102 | super(ResNet, self).__init__()
103 | self.network_type = network_type
104 | # different model config between ImageNet and CIFAR
105 | if network_type == "ImageNet":
106 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
108 | self.avgpool = nn.AvgPool2d(7)
109 | else:
110 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
111 |
112 | self.bn1 = nn.BatchNorm2d(64)
113 | self.relu = nn.ReLU(inplace=True)
114 |
115 | if att_type=='BAM':
116 | self.bam1 = BAM(64*block.expansion)
117 | self.bam2 = BAM(128*block.expansion)
118 | self.bam3 = BAM(256*block.expansion)
119 | else:
120 | self.bam1, self.bam2, self.bam3 = None, None, None
121 |
122 | self.layer1 = self._make_layer(block, 64, layers[0], att_type=att_type)
123 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, att_type=att_type)
124 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, att_type=att_type)
125 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, att_type=att_type)
126 |
127 | self.fc = nn.Linear(512 * block.expansion, num_classes)
128 |
129 | init.kaiming_normal_(self.fc.weight)
130 | for key in self.state_dict():
131 | if key.split('.')[-1]=="weight":
132 | if "conv" in key:
133 | init.kaiming_normal_(self.state_dict()[key], mode='fan_out')
134 | if "bn" in key:
135 | if "SpatialGate" in key:
136 | self.state_dict()[key][...] = 0
137 | else:
138 | self.state_dict()[key][...] = 1
139 | elif key.split(".")[-1]=='bias':
140 | self.state_dict()[key][...] = 0
141 |
142 | def _make_layer(self, block, planes, blocks, stride=1, att_type=None):
143 | downsample = None
144 | if stride != 1 or self.inplanes != planes * block.expansion:
145 | downsample = nn.Sequential(
146 | nn.Conv2d(self.inplanes, planes * block.expansion,
147 | kernel_size=1, stride=stride, bias=False),
148 | nn.BatchNorm2d(planes * block.expansion),
149 | )
150 |
151 | layers = []
152 | layers.append(block(self.inplanes, planes, stride, downsample, use_cbam=att_type=='CBAM'))
153 | self.inplanes = planes * block.expansion
154 | for i in range(1, blocks):
155 | layers.append(block(self.inplanes, planes, use_cbam=att_type=='CBAM'))
156 |
157 | return nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | x = self.conv1(x)
161 | x = self.bn1(x)
162 | x = self.relu(x)
163 | if self.network_type == "ImageNet":
164 | x = self.maxpool(x)
165 |
166 | x = self.layer1(x)
167 | if self.bam1 is not None:
168 | x = self.bam1(x)
169 |
170 | x = self.layer2(x)
171 | if self.bam2 is not None:
172 | x = self.bam2(x)
173 |
174 | x = self.layer3(x)
175 | if self.bam3 is not None:
176 | x = self.bam3(x)
177 |
178 | x = self.layer4(x)
179 |
180 | if self.network_type == "ImageNet":
181 | x = self.avgpool(x)
182 | else:
183 | x = F.avg_pool2d(x, 4)
184 | x = x.view(x.size(0), -1)
185 | x = self.fc(x)
186 | return x
187 |
188 | def ResidualNet(network_type, depth, num_classes, att_type):
189 |
190 | assert network_type in ["ImageNet", "CIFAR10", "CIFAR100"], "network type should be ImageNet or CIFAR10 / CIFAR100"
191 | assert depth in [18, 34, 50, 101], 'network depth should be 18, 34, 50 or 101'
192 |
193 | if depth == 18:
194 | model = ResNet(BasicBlock, [2, 2, 2, 2], network_type, num_classes, att_type)
195 |
196 | elif depth == 34:
197 | model = ResNet(BasicBlock, [3, 4, 6, 3], network_type, num_classes, att_type)
198 |
199 | elif depth == 50:
200 | model = ResNet(Bottleneck, [3, 4, 6, 3], network_type, num_classes, att_type)
201 |
202 | elif depth == 101:
203 | model = ResNet(Bottleneck, [3, 4, 23, 3], network_type, num_classes, att_type)
204 |
205 | return model
206 |
--------------------------------------------------------------------------------
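
A sketch of constructing the full network via `ResidualNet` (hedged: ImageNet mode expects 224x224 inputs, since `nn.AvgPool2d(7)` assumes a 7x7 final feature map):

```python
# Sketch: build ResNet50 with each attention type and run a dummy forward pass.
import torch
from MODELS.model_resnet import ResidualNet

for att_type in [None, 'BAM', 'CBAM']:
    model = ResidualNet('ImageNet', 50, 1000, att_type)
    model.eval()
    logits = model(torch.randn(1, 3, 224, 224))
    print(att_type, logits.shape)  # torch.Size([1, 1000])
```
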
/README.md:
--------------------------------------------------------------------------------
1 | # BAM and CBAM
2 | Official PyTorch code for "[BAM: Bottleneck Attention Module (BMVC2018)](http://bmvc2018.org/contents/papers/0092.pdf)" and "[CBAM: Convolutional Block Attention Module (ECCV2018)](http://openaccess.thecvf.com/content_ECCV_2018/html/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.html)"
3 |
4 | ### Updates & Notices
5 | - 2018-10-08: ~~Currently, only CBAM test code is validated. **There may be minor errors in the training code**. Will be fixed in a few days.~~
6 | - 2018-10-11: Training code validated. RESNET50+BAM pretrained weight added.
7 |
8 | ### Requirement
9 |
10 | The code is validated in the following environment:
11 | - Ubuntu 16.04, 4*GTX 1080 Ti, Docker (PyTorch 0.4.1, CUDA 9.0 + CuDNN 7.0, Python 3.6)
12 |
13 | ### How to use
14 |
15 | ResNet50-based examples are included. Example scripts are in the ```./scripts/``` directory.
16 | ImageNet data should be placed under ```./data/ImageNet/``` with folders named ```train``` and ```val```.
17 |
18 | ```
19 | # To train with BAM (ResNet50 backbone)
20 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type BAM --prefix RESNET50_IMAGENET_BAM ./data/ImageNet
21 | # To train with CBAM (ResNet50 backbone)
22 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --epochs 100 --batch-size 256 --lr 0.1 --att-type CBAM --prefix RESNET50_IMAGENET_CBAM ./data/ImageNet
23 | ```
24 |
25 | ### Resume with checkpoints
26 |
27 | - ResNet50+CBAM (trained for 100 epochs) checkpoint is provided in this [link](https://drive.google.com/file/d/1mvAVvhLR_2XY_bPYxh-SEz4vDmGzSArO/view?usp=sharing). ACC@1=77.622 ACC@5=93.948
28 | - ResNet50+BAM (trained for 90 epochs) checkpoint is provided in this [link](https://drive.google.com/file/d/1auVf70gfL0ol40bvaX5rlbpn9cKIxhAL/view?usp=sharing). ACC@1=76.860 ACC@5=93.416
29 |
30 | For validation, use the following command:
31 | ```
32 | python train_imagenet.py --ngpu 4 --workers 20 --arch resnet --depth 50 --att-type CBAM --prefix EVAL --resume $CHECKPOINT_PATH$ --evaluate ./data/ImageNet
33 | ```
34 |
35 | ### Other implementations
36 |
37 | - [MXNet implementation of CBAM with several modifications](https://github.com/bruinxiong/Modified-CBAMnet.mxnet) by [bruinxiong](https://github.com/bruinxiong)
38 |
--------------------------------------------------------------------------------
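
Since `train_imagenet.py` saves checkpoints from a `torch.nn.DataParallel`-wrapped model, the stored state-dict keys carry a `module.` prefix. A minimal loading sketch for the README's provided checkpoints (hedged: the path follows the naming convention in `save_checkpoint`; adjust it to wherever you saved or downloaded the weights):

```python
import torch
from MODELS.model_resnet import ResidualNet

model = ResidualNet('ImageNet', 50, 1000, 'CBAM')
checkpoint = torch.load('./checkpoints/RESNET50_IMAGENET_CBAM_checkpoint.pth.tar',
                        map_location='cpu')
# keys look like 'module.conv1.weight'; strip the DataParallel prefix
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict)
```
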
/scripts/train_imagenet_resnet50_bam.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 8 \
3 | --workers 20 \
4 | --arch resnet --depth 50 \
5 | --epochs 100 \
6 | --batch-size 256 --lr 0.1 \
7 | --att-type BAM \
8 | --prefix RESNET50_IMAGENET_BAM \
9 | ./data/ImageNet/
10 |
--------------------------------------------------------------------------------
/scripts/train_imagenet_resnet50_cbam.sh:
--------------------------------------------------------------------------------
1 | python train_imagenet.py \
2 | --ngpu 8 \
3 | --workers 20 \
4 | --arch resnet --depth 50 \
5 | --epochs 100 \
6 | --batch-size 256 --lr 0.1 \
7 | --att-type CBAM \
8 | --prefix RESNET50_IMAGENET_CBAM \
9 | ./data/ImageNet/
10 |
--------------------------------------------------------------------------------
/train_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | import time
5 | import random
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.parallel
10 | import torch.backends.cudnn as cudnn
11 | import torch.optim
12 | import torch.utils.data
13 | import torchvision.transforms as transforms
14 | import torchvision.datasets as datasets
15 | import torchvision.models as models
16 | from MODELS.model_resnet import *
17 | from PIL import ImageFile
18 | ImageFile.LOAD_TRUNCATED_IMAGES = True
19 | model_names = sorted(name for name in models.__dict__
20 | if name.islower() and not name.startswith("__")
21 | and callable(models.__dict__[name]))
22 |
23 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
24 | parser.add_argument('data', metavar='DIR',
25 | help='path to dataset')
26 | parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet',
27 | help='model architecture: ' +
28 | ' | '.join(model_names) +
29 | ' (default: resnet)')
30 | parser.add_argument('--depth', default=50, type=int, metavar='D',
31 | help='model depth')
32 | parser.add_argument('--ngpu', default=4, type=int, metavar='G',
33 | help='number of gpus to use')
34 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
35 | help='number of data loading workers (default: 4)')
36 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
37 | help='number of total epochs to run')
38 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
39 | help='manual epoch number (useful on restarts)')
40 | parser.add_argument('-b', '--batch-size', default=256, type=int,
41 | metavar='N', help='mini-batch size (default: 256)')
42 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
43 | metavar='LR', help='initial learning rate')
44 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
45 | help='momentum')
46 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
47 | metavar='W', help='weight decay (default: 1e-4)')
48 | parser.add_argument('--print-freq', '-p', default=10, type=int,
49 | metavar='N', help='print frequency (default: 10)')
50 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
51 | help='path to latest checkpoint (default: none)')
52 | parser.add_argument("--seed", type=int, default=1234, metavar='BS', help='input batch size for training (default: 64)')
53 | parser.add_argument("--prefix", type=str, required=True, metavar='PFX', help='prefix for logging & checkpoint saving')
54 | parser.add_argument('--evaluate', dest='evaluate', action='store_true', help='evaluation only')
55 | parser.add_argument('--att-type', type=str, choices=['BAM', 'CBAM'], default=None)
56 | best_prec1 = 0
57 |
58 | if not os.path.exists('./checkpoints'):
59 | os.mkdir('./checkpoints')
60 |
61 | def main():
62 | global args, best_prec1
63 |
64 | args = parser.parse_args()
65 | print ("args", args)
66 |
67 | torch.manual_seed(args.seed)
68 | torch.cuda.manual_seed_all(args.seed)
69 | random.seed(args.seed)
70 |
71 | # create model
72 | if args.arch == "resnet":
73 | model = ResidualNet( 'ImageNet', args.depth, 1000, args.att_type )
74 |
75 | # define loss function (criterion) and optimizer
76 | criterion = nn.CrossEntropyLoss().cuda()
77 |
78 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
79 | momentum=args.momentum,
80 | weight_decay=args.weight_decay)
81 | model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
82 | #model = torch.nn.DataParallel(model).cuda()
83 | model = model.cuda()
84 | print ("model")
85 | print (model)
86 |
87 | # get the number of model parameters
88 | print('Number of model parameters: {}'.format(
89 | sum([p.data.nelement() for p in model.parameters()])))
90 |
91 | # optionally resume from a checkpoint
92 | if args.resume:
93 | if os.path.isfile(args.resume):
94 | print("=> loading checkpoint '{}'".format(args.resume))
95 | checkpoint = torch.load(args.resume)
96 | args.start_epoch = checkpoint['epoch']
97 | best_prec1 = checkpoint['best_prec1']
98 | model.load_state_dict(checkpoint['state_dict'])
99 | if 'optimizer' in checkpoint:
100 | optimizer.load_state_dict(checkpoint['optimizer'])
101 | print("=> loaded checkpoint '{}' (epoch {})"
102 | .format(args.resume, checkpoint['epoch']))
103 | else:
104 | print("=> no checkpoint found at '{}'".format(args.resume))
105 |
106 |
107 | cudnn.benchmark = True
108 |
109 | # Data loading code
110 | traindir = os.path.join(args.data, 'train')
111 | valdir = os.path.join(args.data, 'val')
112 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
113 | std=[0.229, 0.224, 0.225])
114 |
115 | # import pdb
116 | # pdb.set_trace()
117 | val_loader = torch.utils.data.DataLoader(
118 | datasets.ImageFolder(valdir, transforms.Compose([
119 | transforms.Resize(256),
120 | transforms.CenterCrop(224),
121 | transforms.ToTensor(),
122 | normalize,
123 | ])),
124 | batch_size=args.batch_size, shuffle=False,
125 | num_workers=args.workers, pin_memory=True)
126 | if args.evaluate:
127 | validate(val_loader, model, criterion, 0)
128 | return
129 |
130 | train_dataset = datasets.ImageFolder(
131 | traindir,
132 | transforms.Compose([
133 | transforms.RandomResizedCrop(224),
134 | transforms.RandomHorizontalFlip(),
135 | transforms.ToTensor(),
136 | normalize,
137 | ]))
138 |
139 | train_sampler = None
140 |
141 | train_loader = torch.utils.data.DataLoader(
142 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
143 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
144 |
145 | for epoch in range(args.start_epoch, args.epochs):
146 | adjust_learning_rate(optimizer, epoch)
147 |
148 | # train for one epoch
149 | train(train_loader, model, criterion, optimizer, epoch)
150 |
151 | # evaluate on validation set
152 | prec1 = validate(val_loader, model, criterion, epoch)
153 |
154 | # remember best prec@1 and save checkpoint
155 | is_best = prec1 > best_prec1
156 | best_prec1 = max(prec1, best_prec1)
157 | save_checkpoint({
158 | 'epoch': epoch + 1,
159 | 'arch': args.arch,
160 | 'state_dict': model.state_dict(),
161 | 'best_prec1': best_prec1,
162 | 'optimizer' : optimizer.state_dict(),
163 | }, is_best, args.prefix)
164 |
165 |
166 | def train(train_loader, model, criterion, optimizer, epoch):
167 | batch_time = AverageMeter()
168 | data_time = AverageMeter()
169 | losses = AverageMeter()
170 | top1 = AverageMeter()
171 | top5 = AverageMeter()
172 |
173 | # switch to train mode
174 | model.train()
175 |
176 | end = time.time()
177 | for i, (input, target) in enumerate(train_loader):
178 | # measure data loading time
179 | data_time.update(time.time() - end)
180 |
181 | target = target.cuda(non_blocking=True)
182 | input_var = torch.autograd.Variable(input)
183 | target_var = torch.autograd.Variable(target)
184 |
185 | # compute output
186 | output = model(input_var)
187 | loss = criterion(output, target_var)
188 |
189 | # measure accuracy and record loss
190 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
191 | losses.update(loss.item(), input.size(0))
192 | top1.update(prec1.item(), input.size(0))
193 | top5.update(prec5.item(), input.size(0))
194 |
195 | # compute gradient and do SGD step
196 | optimizer.zero_grad()
197 | loss.backward()
198 | optimizer.step()
199 |
200 | # measure elapsed time
201 | batch_time.update(time.time() - end)
202 | end = time.time()
203 |
204 | if i % args.print_freq == 0:
205 | print('Epoch: [{0}][{1}/{2}]\t'
206 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
207 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
208 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
209 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
210 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
211 | epoch, i, len(train_loader), batch_time=batch_time,
212 | data_time=data_time, loss=losses, top1=top1, top5=top5))
213 |
214 | def validate(val_loader, model, criterion, epoch):
215 | batch_time = AverageMeter()
216 | losses = AverageMeter()
217 | top1 = AverageMeter()
218 | top5 = AverageMeter()
219 |
220 | # switch to evaluate mode
221 | model.eval()
222 |
223 | end = time.time()
224 | for i, (input, target) in enumerate(val_loader):
225 | target = target.cuda(non_blocking=True)
226 | # run inference without autograd (replaces the deprecated volatile=True flag)
227 | with torch.no_grad():
228 |     input_var = input
229 |     target_var = target
230 |     output = model(input_var)
231 |     loss = criterion(output, target_var)
232 |
233 | # measure accuracy and record loss
234 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
235 | losses.update(loss.item(), input.size(0))
236 | top1.update(prec1.item(), input.size(0))
237 | top5.update(prec5.item(), input.size(0))
238 |
239 | # measure elapsed time
240 | batch_time.update(time.time() - end)
241 | end = time.time()
242 |
243 | if i % args.print_freq == 0:
244 | print('Test: [{0}/{1}]\t'
245 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
246 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
247 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
248 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
249 | i, len(val_loader), batch_time=batch_time, loss=losses,
250 | top1=top1, top5=top5))
251 |
252 | print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
253 | .format(top1=top1, top5=top5))
254 |
255 | return top1.avg
256 |
257 |
258 | def save_checkpoint(state, is_best, prefix):
259 | filename='./checkpoints/%s_checkpoint.pth.tar'%prefix
260 | torch.save(state, filename)
261 | if is_best:
262 | shutil.copyfile(filename, './checkpoints/%s_model_best.pth.tar'%prefix)
263 |
264 |
265 | class AverageMeter(object):
266 | """Computes and stores the average and current value"""
267 | def __init__(self):
268 | self.reset()
269 |
270 | def reset(self):
271 | self.val = 0
272 | self.avg = 0
273 | self.sum = 0
274 | self.count = 0
275 |
276 | def update(self, val, n=1):
277 | self.val = val
278 | self.sum += val * n
279 | self.count += n
280 | self.avg = self.sum / self.count
281 |
282 |
283 | def adjust_learning_rate(optimizer, epoch):
284 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
285 | lr = args.lr * (0.1 ** (epoch // 30))
286 | for param_group in optimizer.param_groups:
287 | param_group['lr'] = lr
288 |
289 |
290 | def accuracy(output, target, topk=(1,)):
291 | """Computes the precision@k for the specified values of k"""
292 | maxk = max(topk)
293 | batch_size = target.size(0)
294 |
295 | _, pred = output.topk(maxk, 1, True, True)
296 | pred = pred.t()
297 | correct = pred.eq(target.view(1, -1).expand_as(pred))
298 |
299 | res = []
300 | for k in topk:
301 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
302 | res.append(correct_k.mul_(100.0 / batch_size))
303 | return res
304 |
305 |
306 | if __name__ == '__main__':
307 | main()
308 |
--------------------------------------------------------------------------------
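
For reference, `adjust_learning_rate` implements the step schedule lr = lr0 * 0.1^(epoch // 30). With the default `--lr 0.1`, a quick check of the resulting values:

```python
# Worked example of the step decay in adjust_learning_rate (lr0 = 0.1).
for epoch in [0, 29, 30, 59, 60, 89, 90]:
    print(epoch, 0.1 * (0.1 ** (epoch // 30)))
# epochs 0-29: 0.1 | 30-59: 0.01 | 60-89: 0.001 | 90+: 0.0001
```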