├── Models
│   ├── conv.py
│   ├── attention.py
│   └── resnet.py
├── LICENSE
├── README.md
├── utils.py
└── main.py

/Models/conv.py:
--------------------------------------------------------------------------------
import torch.nn as nn

def conv1x1(in_channels, out_channels, stride=1):
    ''' 1x1 convolution '''
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

def conv3x3(in_channels, out_channels, stride=1, padding=1, dilation=1):
    ''' 3x3 convolution '''
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=False)

def conv7x7(in_channels, out_channels, stride=1, padding=3, dilation=1):
    ''' 7x7 convolution '''
    return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, padding=padding, dilation=dilation, bias=False)
--------------------------------------------------------------------------------
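These helpers follow the common torchvision convention of bias-free convolutions (each is expected to be followed by a BatchNorm layer). A quick shape check — a hypothetical snippet, not part of the repository — confirms that the default padding preserves spatial resolution at stride 1:

```python
import torch
from Models.conv import conv1x1, conv3x3, conv7x7

x = torch.randn(2, 64, 32, 32)                        # (batch, channels, H, W)
assert conv1x1(64, 128)(x).shape == (2, 128, 32, 32)
assert conv3x3(64, 64)(x).shape == (2, 64, 32, 32)    # padding=1 keeps 32x32
assert conv7x7(64, 64)(x).shape == (2, 64, 32, 32)    # padding=3 keeps 32x32
```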
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Gyeongbo Sim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
BAM & CBAM Pytorch
==================

[PyTorch](https://pytorch.org/) implementation of BAM and CBAM.

This repository evaluates popular attention-module architectures, such as BAM and CBAM, on the CIFAR datasets.

> Park J, Woo S, Lee J Y, Kweon I S. BAM: Bottleneck Attention Module. 2018. [BMVC 2018 (Oral)](https://arxiv.org/pdf/1807.06514.pdf)

> Woo S, Park J, Lee J Y, Kweon I S. CBAM: Convolutional Block Attention Module. 2018. [ECCV 2018](https://arxiv.org/pdf/1807.06521.pdf)

#### Architecture

BAM
![image](https://user-images.githubusercontent.com/26369382/98519653-693d1300-22b4-11eb-8f29-fd7ff2520ee5.png)

CBAM
![image](https://user-images.githubusercontent.com/26369382/98519785-9689c100-22b4-11eb-8bc6-b9fd0445f258.png)

#### Getting Started
```bash
$ git clone https://github.com/asdf2kr/BAM-CBAM-pytorch.git
$ cd BAM-CBAM-pytorch
$ python main.py --arch bam   # default: BAM attention on a ResNet-50 backbone
```

#### Performance
The table below lists each model's backbone, dataset, accuracy, and parameter count.

Model | Backbone | Dataset | Top-1 | Top-5 | Size
:----:| :----:| :------:| :----:|:-----:|:----:
ResNet| resnet50 |CIFAR-100 | 78.93% | - | 23.70M
BAM | resnet50 |CIFAR-100 | 79.62% | - | 24.06M
CBAM | resnet50 |CIFAR-100 | 81.02% | - | 26.23M

#### Reference
[Official PyTorch code](https://github.com/Jongchan/attention-module)
--------------------------------------------------------------------------------
/Models/attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from Models.conv import conv1x1, conv3x3, conv7x7

class BAM(nn.Module):
    ''' Bottleneck Attention Module (Park et al., BMVC 2018). '''
    def __init__(self, in_channel, reduction_ratio, dilation):
        super(BAM, self).__init__()
        self.hid_channel = in_channel // reduction_ratio
        self.dilation = dilation
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        # Channel branch: bottleneck MLP on globally pooled features.
        self.fc1 = nn.Linear(in_features=in_channel, out_features=self.hid_channel)
        self.bn1_1d = nn.BatchNorm1d(self.hid_channel)
        self.fc2 = nn.Linear(in_features=self.hid_channel, out_features=in_channel)
        self.bn2_1d = nn.BatchNorm1d(in_channel)

        # Spatial branch: 1x1 reduction, two dilated 3x3 convs, 1x1 projection to one map.
        self.conv1 = conv1x1(in_channel, self.hid_channel)
        self.bn1_2d = nn.BatchNorm2d(self.hid_channel)
        self.conv2 = conv3x3(self.hid_channel, self.hid_channel, stride=1, padding=self.dilation, dilation=self.dilation)
        self.bn2_2d = nn.BatchNorm2d(self.hid_channel)
        self.conv3 = conv3x3(self.hid_channel, self.hid_channel, stride=1, padding=self.dilation, dilation=self.dilation)
        self.bn3_2d = nn.BatchNorm2d(self.hid_channel)
        self.conv4 = conv1x1(self.hid_channel, 1)
        self.bn4_2d = nn.BatchNorm2d(1)

    def forward(self, x):
        # Channel attention: Mc has shape (B, C, 1, 1).
        Mc = self.globalAvgPool(x)
        Mc = Mc.view(Mc.size(0), -1)

        Mc = self.fc1(Mc)
        Mc = self.bn1_1d(Mc)
        Mc = self.relu(Mc)

        Mc = self.fc2(Mc)
        Mc = self.bn2_1d(Mc)
        Mc = self.relu(Mc)

        Mc = Mc.view(Mc.size(0), Mc.size(1), 1, 1)

        # Spatial attention: Ms has shape (B, 1, H, W).
        Ms = self.conv1(x)
        Ms = self.bn1_2d(Ms)
        Ms = self.relu(Ms)

        Ms = self.conv2(Ms)
        Ms = self.bn2_2d(Ms)
        Ms = self.relu(Ms)

        Ms = self.conv3(Ms)
        Ms = self.bn3_2d(Ms)
        Ms = self.relu(Ms)

        Ms = self.conv4(Ms)
        Ms = self.bn4_2d(Ms)
        Ms = self.relu(Ms)

        # Broadcast the two maps and apply the residual gate:
        # M(F) = 1 + sigmoid(Mc(F) * Ms(F)), so F' = F * M(F).
        Ms = Ms.view(x.size(0), 1, x.size(2), x.size(3))
        Mf = 1 + self.sigmoid(Mc * Ms)
        return x * Mf

class CBAM(nn.Module):
    ''' Convolutional Block Attention Module (Woo et al., ECCV 2018). '''
    def __init__(self, in_channel, reduction_ratio, dilation=1):
        super(CBAM, self).__init__()
        self.hid_channel = in_channel // reduction_ratio
        self.dilation = dilation

        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.globalMaxPool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP.
        self.mlp = nn.Sequential(
            nn.Linear(in_features=in_channel, out_features=self.hid_channel),
            nn.ReLU(),
            nn.Linear(in_features=self.hid_channel, out_features=in_channel)
        )

        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

        self.conv1 = conv7x7(2, 1, stride=1, dilation=self.dilation)

    def forward(self, x):
        ''' Channel attention: sigmoid(MLP(AvgPool(F)) + MLP(MaxPool(F))). '''
        avgOut = self.globalAvgPool(x)
        avgOut = avgOut.view(avgOut.size(0), -1)
        avgOut = self.mlp(avgOut)

        maxOut = self.globalMaxPool(x)
        maxOut = maxOut.view(maxOut.size(0), -1)
        maxOut = self.mlp(maxOut)

        Mc = self.sigmoid(avgOut + maxOut)
        Mc = Mc.view(Mc.size(0), Mc.size(1), 1, 1)
        Mf1 = Mc * x

        ''' Spatial attention: sigmoid(conv7x7([AvgPool(F); MaxPool(F)])). '''
        maxOut = torch.max(Mf1, 1)[0].unsqueeze(1)
        avgOut = torch.mean(Mf1, 1).unsqueeze(1)
        Ms = torch.cat((maxOut, avgOut), dim=1)

        Ms = self.conv1(Ms)
        Ms = self.sigmoid(Ms)
        Ms = Ms.view(Ms.size(0), 1, Ms.size(2), Ms.size(3))
        Mf2 = Ms * Mf1
        return Mf2
--------------------------------------------------------------------------------
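Both modules are drop-in: each takes a feature map of shape (B, C, H, W) and returns a refined map of the same shape. A minimal smoke test — hypothetical usage with the papers' default hyperparameters, not part of the repository:

```python
import torch
from Models.attention import BAM, CBAM

x = torch.randn(4, 256, 28, 28)                        # e.g. a conv3-stage feature map
bam = BAM(in_channel=256, reduction_ratio=16, dilation=4)
cbam = CBAM(in_channel=256, reduction_ratio=16)

assert bam(x).shape == x.shape    # BAM: F' = F * (1 + sigmoid(Mc * Ms))
assert cbam(x).shape == x.shape   # CBAM: channel gate, then spatial gate
```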
/utils.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision
import torchvision.transforms as transforms
# https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset

def prepare_dataloaders(args):
    '''
    Build train/validation dataloaders for CIFAR-10, CIFAR-100 or ImageNet.
    pytorch.org/docs/stable/torchvision/datasets.html#imagenet
    '''
    data_path = os.path.join(os.getcwd(), 'Datas')

    if args.datasets == 'cifar10':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        train_dataset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=train_transform)
        valid_dataset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=valid_transform)

    elif args.datasets == 'cifar100':
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        train_dataset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_transform)
        valid_dataset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=valid_transform)

    elif args.datasets == 'imagenet':
        train_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        valid_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225)),
        ])
        # ImageNet 2012 classification dataset. Note: recent torchvision releases
        # no longer support download=True for ImageNet; place the archives under
        # data_path manually.
        train_dataset = torchvision.datasets.ImageNet(root=data_path, split='train', download=True, transform=train_transform)
        valid_dataset = torchvision.datasets.ImageNet(root=data_path, split='val', download=True, transform=valid_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch,
                                               num_workers=args.workers,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch,
                                               num_workers=args.workers)

    return train_loader, valid_loader, len(train_dataset), len(valid_dataset)
--------------------------------------------------------------------------------
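prepare_dataloaders only reads the datasets, batch and workers attributes, so it can be exercised without the full CLI. A minimal sketch, using argparse.Namespace as a hypothetical stand-in for the parsed arguments:

```python
from argparse import Namespace
from utils import prepare_dataloaders

# Hypothetical stand-in for the parsed CLI arguments; only these three fields are read.
args = Namespace(datasets='cifar100', batch=256, workers=0)
train_loader, valid_loader, n_train, n_valid = prepare_dataloaders(args)
print(n_train, n_valid)   # 50000 10000 for CIFAR-100
```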
/Models/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from Models.attention import BAM, CBAM
from Models.conv import conv1x1, conv3x3

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']


class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_channels, hid_channels, atte='bam', ratio=16, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_channels, hid_channels, stride)
        self.bn1 = nn.BatchNorm2d(hid_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(hid_channels, hid_channels)
        self.bn2 = nn.BatchNorm2d(hid_channels)
        self.downsample = downsample

        # CBAM is applied inside every block; BAM is applied between stages instead.
        if atte == 'cbam':
            self.atte = CBAM(hid_channels, ratio)
        else:
            self.atte = None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        # CBAM refines the block output before the residual addition.
        if self.atte is not None:
            out = self.atte(out)

        out += residual
        out = self.relu(out)

        return out

class BottleneckBlock(nn.Module):
    ''' Bottleneck block, used for ResNet-50 and deeper. '''
    expansion = 4
    def __init__(self, in_channels, hid_channels, atte='bam', ratio=16, stride=1, downsample=None):
        super(BottleneckBlock, self).__init__()
        self.downsample = downsample
        out_channels = hid_channels * self.expansion
        self.conv1 = conv1x1(in_channels, hid_channels)
        self.bn1 = nn.BatchNorm2d(hid_channels)

        self.conv2 = conv3x3(hid_channels, hid_channels, stride)
        self.bn2 = nn.BatchNorm2d(hid_channels)

        self.conv3 = conv1x1(hid_channels, out_channels)
        self.bn3 = nn.BatchNorm2d(out_channels)

        self.relu = nn.ReLU(inplace=True)

        if atte == 'cbam':
            self.atte = CBAM(out_channels, ratio)
        else:
            self.atte = None

    def forward(self, x):
        residual = x  # identity

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        if self.atte is not None:
            out = self.atte(out)

        out += residual
        out = self.relu(out)

        return out

class ResNet(nn.Module):
    '''
    *50-layer
        conv1 (output: 112x112)  7x7, 64, stride 2
        conv2 (output: 56x56)    3x3 max pool, stride 2
                                 [ 1x1, 64  ]
                                 [ 3x3, 64  ] x 3
                                 [ 1x1, 256 ]
        conv3 (output: 28x28)    [ 1x1, 128 ]
                                 [ 3x3, 128 ] x 4
                                 [ 1x1, 512 ]
        conv4 (output: 14x14)    [ 1x1, 256 ]
                                 [ 3x3, 256 ] x 6
                                 [ 1x1, 1024]
        conv5 (output: 7x7)      [ 1x1, 512 ]
                                 [ 3x3, 512 ] x 3
                                 [ 1x1, 2048]
        (output: 1x1)            average pool, 1000-d fc, softmax
        FLOPs 3.8x10^9

    *101-layer: identical except conv4 repeats x 23; FLOPs 7.6x10^9.
    '''
    def __init__(self, block, layers, num_classes=1000, atte='bam', ratio=16, dilation=4):
        super(ResNet, self).__init__()
        self.layers = layers
        self.in_channels = 64
        self.atte = atte
        self.ratio = ratio
        self.dilation = dilation

        # ImageNet-sized inputs use the 7x7/stride-2 stem with max pooling;
        # smaller inputs (CIFAR) use a 3x3/stride-1 stem instead.
        if num_classes == 1000:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            )
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True)
            )

        # BAM modules sit at the three bottlenecks between stages.
        if self.atte == 'bam':
            self.bam1 = BAM(64*block.expansion, self.ratio, self.dilation)
            self.bam2 = BAM(128*block.expansion, self.ratio, self.dilation)
            self.bam3 = BAM(256*block.expansion, self.ratio, self.dilation)

        self.conv2 = self.get_layers(block, 64, self.layers[0])
        self.conv3 = self.get_layers(block, 128, self.layers[1], stride=2)
        self.conv4 = self.get_layers(block, 256, self.layers[2], stride=2)
        self.conv5 = self.get_layers(block, 512, self.layers[3], stride=2)
        self.avgPool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        torch.nn.init.kaiming_normal_(self.fc.weight)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def get_layers(self, block, hid_channels, n_layers, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != hid_channels * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, hid_channels * block.expansion, stride),
                nn.BatchNorm2d(hid_channels * block.expansion),
            )
        layers = []
        layers.append(block(self.in_channels, hid_channels, self.atte, self.ratio, stride, downsample))
        self.in_channels = hid_channels * block.expansion

        for _ in range(1, n_layers):
            layers.append(block(self.in_channels, hid_channels, self.atte, self.ratio))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)

        x = self.conv2(x)
        if self.atte == 'bam':
            x = self.bam1(x)

        x = self.conv3(x)
        if self.atte == 'bam':
            x = self.bam2(x)

        x = self.conv4(x)
        if self.atte == 'bam':
            x = self.bam3(x)

        x = self.conv5(x)
        x = self.avgPool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

def resnet18(**kwargs):
    return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)

def resnet34(**kwargs):
    return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)

def resnet50(**kwargs):
    return ResNet(BottleneckBlock, [3, 4, 6, 3], **kwargs)

def resnet101(**kwargs):
    ''' ResNet-101 model '''
    return ResNet(BottleneckBlock, [3, 4, 23, 3], **kwargs)

def resnet152(**kwargs):
    return ResNet(BottleneckBlock, [3, 8, 36, 3], **kwargs)
--------------------------------------------------------------------------------
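A hypothetical smoke test for the constructors, not part of the repository: because num_classes != 1000, the CIFAR-style 3x3 stem is selected and 32x32 inputs pass through cleanly.

```python
import torch
import Models.resnet as resnet

model = resnet.resnet50(num_classes=100, atte='bam', ratio=16, dilation=4)
model.eval()                    # use running BatchNorm statistics
x = torch.randn(2, 3, 32, 32)   # CIFAR-sized input
with torch.no_grad():
    out = model(x)
assert out.shape == (2, 100)    # one logit per CIFAR-100 class
```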
/main.py:
--------------------------------------------------------------------------------
import os
import time
import shutil
import argparse

import torch
import torch.nn
import torch.optim

import Models.resnet as resnet

from utils import prepare_dataloaders
from tqdm import tqdm
'''
reference:
    pytorch, torchvision
    conda install -c conda-forge torchvision
'''
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def main():
    ''' Main function '''
    parser = argparse.ArgumentParser(description='Image classification on the CIFAR/ImageNet datasets using PyTorch')
    parser.add_argument('--arch', default='bam', type=str, help='attention model (bam, cbam)')
    parser.add_argument('--backbone', default='resnet50', type=str, help='backbone classification model (resnet18, 34, 50, 101, 152)')
    parser.add_argument('--epoch', default=1, type=int, help='start epoch')
    parser.add_argument('--n_epochs', default=350, type=int, help='number of total epochs to run')
    parser.add_argument('--batch', default=256, type=int, help='mini-batch size (default: 256)')
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
    parser.add_argument('--save_directory', default='trained.chkpt', type=str, help='path to latest checkpoint')
    parser.add_argument('--workers', default=0, type=int, help='num_workers')
    parser.add_argument('--resume', default=False, type=bool, help='resume from checkpoint')
    parser.add_argument('--datasets', default='CIFAR100', type=str, help='classification dataset (CIFAR10, CIFAR100, ImageNet)')
    parser.add_argument('--weight_decay', default=5e-4, type=float, help='weight_decay')
    parser.add_argument('--save', default='trained', type=str, help='checkpoint prefix (single GPU)')
    parser.add_argument('--save_multi', default='trained_multi', type=str, help='checkpoint prefix (multi GPU)')
    parser.add_argument('--evaluate', default=False, type=bool, help='evaluate only')
    parser.add_argument('--reduction_ratio', default=16, type=int, help='reduction ratio')
    parser.add_argument('--dilation_value', default=4, type=int, help='dilation value')
    args = parser.parse_args()
    args.arch = args.arch.lower()
    args.backbone = args.backbone.lower()
    args.datasets = args.datasets.lower()

    if not os.path.isdir('checkpoints'):
        os.mkdir('checkpoints')
    # To-do: Write code relating to the random seed.

    # Use GPU, multi-GPU or CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_multi_gpu = torch.cuda.device_count() > 1
    print('[Info] device:{} use_multi_gpu:{}'.format(device, use_multi_gpu))

    if args.datasets == 'cifar10':
        num_classes = 10
    elif args.datasets == 'cifar100':
        num_classes = 100
    elif args.datasets == 'imagenet':
        num_classes = 1000

    # Load the data.
    print('[Info] Load the data.')
    train_loader, valid_loader, train_size, valid_size = prepare_dataloaders(args)

    # Load the model.
    print('[Info] Load the model.')
    if args.backbone == 'resnet18':
        model = resnet.resnet18(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet34':
        model = resnet.resnet34(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet50':
        model = resnet.resnet50(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet101':
        model = resnet.resnet101(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)
    elif args.backbone == 'resnet152':
        model = resnet.resnet152(num_classes=num_classes, atte=args.arch, ratio=args.reduction_ratio, dilation=args.dilation_value)

    model = model.to(device)
    if use_multi_gpu:
        model = torch.nn.DataParallel(model)
    print('[Info] Total parameters {}'.format(count_parameters(model)))
    # Define the loss function.
    criterion = torch.nn.CrossEntropyLoss().to(device)

    # Define the optimizer.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    if args.resume:
        # Load the checkpoint.
        print('[Info] Loading checkpoint.')
        checkpoint = load_checkpoint(args.save)

        backbone = checkpoint['backbone']
        args.epoch = checkpoint['epoch']
        state_dict = checkpoint['state_dict']
        model.load_state_dict(state_dict)
        print('[Info] epoch {} backbone {}'.format(args.epoch, backbone))

    # Run evaluation only.
    if args.evaluate:
        _ = run_epoch(model, 'valid', [args.epoch, args.epoch], criterion, optimizer, valid_loader, valid_size, device)
        return

    # Run training.
    best_acc1 = 0.
    for e in range(args.epoch, args.n_epochs + 1):
        adjust_learning_rate(optimizer, e, args)

        # Train for one epoch.
        _ = run_epoch(model, 'train', [e, args.n_epochs], criterion, optimizer, train_loader, train_size, device)

        # Evaluate on the validation set.
        with torch.no_grad():
            acc1 = run_epoch(model, 'valid', [e, args.n_epochs], criterion, optimizer, valid_loader, valid_size, device)

        # Save checkpoint.
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        save_checkpoint({
            'epoch': e,
            'backbone': args.backbone,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }, is_best, args.save)

        if use_multi_gpu:
            # Also save the unwrapped weights so they load without DataParallel.
            save_checkpoint({
                'epoch': e,
                'backbone': args.backbone,
                'state_dict': model.module.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save_multi)

        print('[Info] acc1 {} best@acc1 {}'.format(acc1, best_acc1))

def run_epoch(model, mode, epoch, criterion, optimizer, data_loader, dataset_size, device):
    if mode == 'train':
        model.train()
    else:
        model.eval()

    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    start = time.time()
    tq = tqdm(data_loader, desc=' - (' + mode + ') ', leave=False)
    for data, target in tq:
        # Prepare data.
        data, target = data.to(device), target.to(device)

        # Forward pass.
        output = model(data)
        loss = criterion(output, target)

        # Measure accuracy and record loss.
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
        losses.update(loss.item(), data.size(0))
        top1.update(prec1[0], data.size(0))
        top5.update(prec5[0], data.size(0))

        if mode == 'train':
            # Compute gradients and take an SGD step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tq.set_description(' - ({}) [ epoch: {}/{} loss: {:.3f}/{:.3f} ] '.format(mode, epoch[0], epoch[1], losses.val, losses.avg))
    tqdm.write(' - ({}) [ epoch: {}/{}\ttop@1: {:.3f}\ttop@5: {:.3f}\tloss: {:.3f}\ttime: {:.3f}m ]'.format(mode, epoch[0], epoch[1], top1.avg, top5.avg, losses.avg, (time.time() - start) / 60.))

    return top1.avg

def save_checkpoint(state, is_best, prefix):
    filename = 'checkpoints/{}_checkpoint.chkpt'.format(prefix)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'checkpoints/{}_best.chkpt'.format(prefix))
        print(' - [Info] The checkpoint file has been updated.')

def load_checkpoint(prefix):
    filename = 'checkpoints/{}_checkpoint.chkpt'.format(prefix)
    return torch.load(filename)

class AverageMeter(object):
    ''' Computes and stores the average and current value. '''
    def __init__(self):
        self.reset()
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(optimizer, epoch, args):
    # Step decay: divide the learning rate by 10 every 100 epochs.
    lr = args.lr * (0.1 ** (epoch // 100))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def accuracy(output, target, topk=(1,)):
    ''' Computes the accuracy over the k top predictions for the specified values of k. '''
    with torch.no_grad():
        # https://pytorch.org/docs/stable/torch.html#torch.topk
        maxk = max(topk)
        bsz = target.size(0)
        _, pred = output.topk(maxk, 1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / bsz))
        return res

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
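To illustrate how accuracy() behaves, here is a small worked example with hypothetical logits for a 4-sample, 3-class batch: two of the four top-1 predictions are wrong, but every target appears within the top 2, so top-1 is 50% while top-2 is 100%.

```python
import torch
from main import accuracy

# Hypothetical logits; rows are samples, columns are class scores.
output = torch.tensor([[2.0, 1.0, 0.1],
                       [0.1, 2.0, 1.0],
                       [2.0, 1.0, 0.1],
                       [1.0, 0.1, 2.0]])
target = torch.tensor([0, 2, 1, 2])

top1, top2 = accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())   # 50.0 100.0
```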