├── .gitignore
├── figures
│   ├── attentions.png
│   └── training_curves.png
├── networks
│   ├── cifar
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   ├── wideresnet.py
│   │   ├── resnet.py
│   │   ├── preresnet.py
│   │   └── block.py
│   ├── imagenet
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   └── resnet.py
│   ├── attentions
│   │   ├── __init__.py
│   │   ├── se_module.py
│   │   ├── simam_module.py
│   │   └── cbam_module.py
│   └── __init__.py
├── mmdetection
│   ├── configs
│   │   ├── mask_rcnn
│   │   │   ├── mask_rcnn_r101simam_fpn_1x_coco.py
│   │   │   └── mask_rcnn_r50simam_fpn_1x_coco.py
│   │   ├── faster_rcnn
│   │   │   ├── faster_rcnn_r101simam_fpn_1x_coco.py
│   │   │   └── faster_rcnn_r50simam_fpn_1x_coco.py
│   │   └── _base_
│   │       ├── schedules
│   │       │   └── schedule_1x_lr0.01.py
│   │       └── models
│   │           ├── faster_rcnn_r50simam_fpn.py
│   │           └── mask_rcnn_r50simam_fpn.py
│   └── mmdet
│       └── models
│           └── backbones
│               ├── attentions
│               │   └── simam_module.py
│               └── resnet_simam.py
├── checkpoint.py
├── util.py
├── README.md
├── main_cifar.py
└── main_imagenet.py
/.gitignore:
--------------------------------------------------------------------------------
1 | pretrained
2 | SimAM-previous.rar
--------------------------------------------------------------------------------
/figures/attentions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/figures/attentions.png
--------------------------------------------------------------------------------
/figures/training_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/figures/training_curves.png
--------------------------------------------------------------------------------
/networks/attentions/__init__.py:
--------------------------------------------------------------------------------
1 | from .. import find_module_using_name
2 |
3 |
4 |
5 | def get_attention_module(attention_type="none"):
6 |
7 | return find_module_using_name(attention_type.lower())
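8 | 
9 | # Hypothetical usage sketch: resolve an attention class by name and build it;
10 | # "none" yields None, so callers must guard against that.
11 | # attn_cls = get_attention_module("simam")
12 | # attn = attn_cls(e_lambda=1e-4) if attn_cls is not None else None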
--------------------------------------------------------------------------------
/mmdetection/configs/mask_rcnn/mask_rcnn_r101simam_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = './mask_rcnn_r50simam_fpn_1x_coco.py'
2 | model = dict(pretrained='checkpoints/simam-net/resnet101.pth.tar', backbone=dict(depth=101))
--------------------------------------------------------------------------------
/mmdetection/configs/faster_rcnn/faster_rcnn_r101simam_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = './faster_rcnn_r50simam_fpn_1x_coco.py'
2 | model = dict(pretrained='checkpoints/simam-net/resnet101.pth.tar', backbone=dict(depth=101))
3 |
--------------------------------------------------------------------------------
/mmdetection/configs/mask_rcnn/mask_rcnn_r50simam_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/mask_rcnn_r50simam_fpn.py',
3 | '../_base_/datasets/coco_instance.py',
4 | '../_base_/schedules/schedule_1x_lr0.01.py', '../_base_/default_runtime.py'
5 | ]
6 |
7 | find_unused_parameters=True
--------------------------------------------------------------------------------
/mmdetection/configs/faster_rcnn/faster_rcnn_r50simam_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/faster_rcnn_r50simam_fpn.py',
3 | '../_base_/datasets/coco_detection.py',
4 | '../_base_/schedules/schedule_1x_lr0.01.py', '../_base_/default_runtime.py'
5 | ]
6 |
7 | find_unused_parameters=True
8 |
--------------------------------------------------------------------------------
/mmdetection/configs/_base_/schedules/schedule_1x_lr0.01.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 | policy='step',
7 | warmup='linear',
8 | warmup_iters=500,
9 | warmup_ratio=0.001,
10 | step=[8, 11])
11 | total_epochs = 12
12 |
--------------------------------------------------------------------------------
/networks/attentions/se_module.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | class se_module(nn.Module):
4 | def __init__(self, channel, reduction=16):
5 | super(se_module, self).__init__()
6 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
7 | self.fc = nn.Sequential(
8 | nn.Linear(channel, int(channel // reduction), bias=False),
9 | nn.ReLU(inplace=True),
10 | nn.Linear(int(channel // reduction), channel, bias=False),
11 | nn.Sigmoid()
12 | )
13 |
14 | @staticmethod
15 | def get_module_name():
16 | return "se"
17 |
18 | def forward(self, x):
19 | b, c, _, _ = x.size()
20 | y = self.avg_pool(x).view(b, c)
21 | y = self.fc(y).view(b, c, 1, 1)
22 | return x * y
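23 | 
24 | # Shape-preserving smoke test (hypothetical; needs `import torch`):
25 | # y = se_module(channel=64)(torch.randn(2, 64, 8, 8))  # y.shape == (2, 64, 8, 8)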
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.nn as nn
3 |
4 | def find_module_using_name(module_name):
5 |
6 | if module_name == "none":
7 | return None
8 |
9 | module_filename = "networks.attentions." + module_name + "_module"
10 | modellib = importlib.import_module(module_filename)
11 |
12 | module = None
13 | target_model_name = module_name + '_module'
14 |
15 | for name, cls in modellib.__dict__.items():
16 | if name.lower() == target_model_name.lower() and issubclass(cls, nn.Module):
17 | module = cls
18 |
19 | if module is None:
20 | print("In %s.py, there should be a subclass of nn.Module with class name that matches %s in lowercase." % (module_filename, target_model_name))
21 | exit(0)
22 |
23 | return module
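24 | 
25 | # Hypothetical lookup sketch: "simam" maps to networks.attentions.simam_module
26 | # and returns the class `simam_module` found by the case-insensitive scan above.
27 | # simam_cls = find_module_using_name("simam")
28 | # attn = simam_cls(e_lambda=1e-4)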
--------------------------------------------------------------------------------
/networks/attentions/simam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class simam_module(torch.nn.Module):
6 | def __init__(self, channels = None, e_lambda = 1e-4):
7 | super(simam_module, self).__init__()
8 |
9 |         self.activation = nn.Sigmoid()
10 | self.e_lambda = e_lambda
11 |
12 | def __repr__(self):
13 | s = self.__class__.__name__ + '('
14 | s += ('lambda=%f)' % self.e_lambda)
15 | return s
16 |
17 | @staticmethod
18 | def get_module_name():
19 | return "simam"
20 |
21 | def forward(self, x):
22 |
23 | b, c, h, w = x.size()
24 |
25 | n = w * h - 1
26 |
27 | x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
28 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
29 |
30 |         return x * self.activation(y)
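31 | 
32 | # Minimal smoke test (hypothetical): the module is parameter-free and
33 | # shape-preserving for any [N, C, H, W] input.
34 | # out = simam_module(e_lambda=1e-4)(torch.randn(2, 64, 32, 32))
35 | # assert out.shape == (2, 64, 32, 32)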
--------------------------------------------------------------------------------
/mmdetection/mmdet/models/backbones/attentions/simam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class simam_module(torch.nn.Module):
6 | def __init__(self, channels = None, e_lambda = 1e-4):
7 | super(simam_module, self).__init__()
8 |
9 |         self.activation = nn.Sigmoid()
10 | self.e_lambda = e_lambda
11 |
12 | def __repr__(self):
13 | s = self.__class__.__name__ + '('
14 | s += ('lambda=%f)' % self.e_lambda)
15 | return s
16 |
17 | @staticmethod
18 | def get_module_name():
19 | return "simam"
20 |
21 | def forward(self, x):
22 |
23 | b, c, h, w = x.size()
24 |
25 | n = w * h - 1
26 |
27 | x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
28 | y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
29 |
30 |         return x * self.activation(y)
--------------------------------------------------------------------------------
/networks/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from ..attentions import get_attention_module
3 | from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d
4 | from .mobilenetv2 import mobilenet_v2
5 | from .resnet import resnet50d
6 |
7 | model_dict = {
8 | "mobilenet_v2": mobilenet_v2,
9 | "resnet18": resnet18,
10 | "resnet34": resnet34,
11 | "resnet50": resnet50,
12 | "resnet101": resnet101,
13 | "resnet152": resnet152,
14 | "resnet50d": resnet50d,
15 | "resnext50_32x4d": resnext50_32x4d
16 | }
17 |
18 |
19 | def create_net(args):
20 | net = None
21 |
22 | attention_module = get_attention_module(args.attention_type)
23 |
24 |     # se/cbam take a reduction ratio; simam takes the coefficient e_lambda
25 | if args.attention_type == "se" or args.attention_type == "cbam":
26 | attention_module = functools.partial(attention_module, reduction=args.attention_param)
27 | elif args.attention_type == "simam":
28 | attention_module = functools.partial(attention_module, e_lambda=args.attention_param)
29 |
30 | kwargs = {}
31 | kwargs["num_classes"] = 1000
32 | kwargs["attention_module"] = attention_module
33 |
34 | net = model_dict[args.arch.lower()](**kwargs)
35 |
36 | return net
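37 | 
38 | # Hypothetical call sketch: build SimAM-ResNet50 from an argparse-style namespace.
39 | # from types import SimpleNamespace
40 | # args = SimpleNamespace(arch="resnet50", attention_type="simam", attention_param=0.1)
41 | # net = create_net(args)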
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import shutil
4 |
5 | def save_checkpoint(state, is_best, epoch, save_path='./'):
6 | print("=> saving checkpoint '{}'".format(epoch))
7 | torch.save(state, os.path.join(save_path, 'checkpoint.pth.tar'))
8 | if(epoch % 10 == 0):
9 | torch.save(state, os.path.join(save_path, 'checkpoint_%03d.pth.tar' % epoch))
10 | if is_best:
11 | if epoch >= 90:
12 | shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
13 | os.path.join(save_path, 'model_best_in_100_epochs.pth.tar'))
14 | else:
15 | shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
16 | os.path.join(save_path, 'model_best_in_090_epochs.pth.tar'))
17 |
18 |
19 | def load_checkpoint(args, model, optimizer=None, verbose=True):
20 |
21 | checkpoint = torch.load(args.resume)
22 |
23 | start_epoch = 0
24 | best_acc = 0
25 |
26 | if "epoch" in checkpoint:
27 | start_epoch = checkpoint['epoch']
28 |
29 | if "best_acc" in checkpoint:
30 | best_acc = checkpoint['best_acc']
31 |
32 |     model.load_state_dict(checkpoint['state_dict'], strict=False)
33 |
34 | if optimizer is not None and "optimizer" in checkpoint:
35 | optimizer.load_state_dict(checkpoint['optimizer'])
36 |
37 | for state in optimizer.state.values():
38 | for k, v in state.items():
39 | if isinstance(v, torch.Tensor):
40 | state[k] = v.to(args.device)
41 |
42 | if verbose:
43 | print("=> loading checkpoint '{}' (epoch {})"
44 | .format(args.resume, start_epoch))
45 |
46 | return model, optimizer, best_acc, start_epoch
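47 | 
48 | # Hypothetical resume sketch: `args` needs `resume` (checkpoint path) and `device`.
49 | # model, optimizer, best_acc, start_epoch = load_checkpoint(args, model, optimizer)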
--------------------------------------------------------------------------------
/networks/cifar/__init__.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from ..attentions import get_attention_module
3 | from .mobilenetv2 import MobileNetV2Wrapper
4 | from .resnet import ResNet20, ResNet32, ResNet56, ResNet110, ResNet164
5 | from .preresnet import PreResNet20, PreResNet32, PreResNet56, PreResNet110, PreResNet164
6 | from .wideresnet import WideResNet28x10, WideResNet40x10
7 | from .block import BasicBlock, BottleNect, PreBasicBlock, PreBottleNect, InvertedResidualBlock, WideBasicBlock
8 |
9 |
10 | model_dict = {
11 | "resnet20": ResNet20,
12 | "resnet32": ResNet32,
13 | "resnet56": ResNet56,
14 | 'resnet110': ResNet110,
15 | "resnet164": ResNet164,
16 | "preresnet20": PreResNet20,
17 | "preresnet32": PreResNet32,
18 | "preresnet56": PreResNet56,
19 | 'preresnet110': PreResNet110,
20 | "preresnet164": PreResNet164,
21 | "wideresnet28x10": WideResNet28x10,
22 | "wideresnet40x10": WideResNet40x10,
23 | "mobilenetv2": MobileNetV2Wrapper,
24 | }
25 |
26 | def get_block(block_type="basic"):
27 |
28 | block_type = block_type.lower()
29 |
30 | if block_type == "basic":
31 | b = BasicBlock
32 | elif block_type == "bottlenect":
33 | b = BottleNect
34 | elif block_type == "prebasic":
35 | b = PreBasicBlock
36 | elif block_type == "prebottlenect":
37 | b = PreBottleNect
38 | elif block_type == "ivrd":
39 | b = InvertedResidualBlock
40 | elif block_type == "widebasic":
41 | b = WideBasicBlock
42 | else:
43 |         raise NotImplementedError('block [%s] is not found' % block_type)
44 | return b
45 |
46 |
47 | def create_net(args):
48 | net = None
49 |
50 | block_module = get_block(args.block_type)
51 | attention_module = get_attention_module(args.attention_type)
52 |
53 | if args.attention_type == "se" or args.attention_type == "cbam":
54 | attention_module = functools.partial(attention_module, reduction=args.attention_param)
55 | elif args.attention_type == "simam":
56 | attention_module = functools.partial(attention_module, e_lambda=args.attention_param)
57 |
58 | net = model_dict[args.arch.lower()](
59 | num_class = args.num_class,
60 | block = block_module,
61 | attention_module = attention_module
62 | )
63 |
64 | return net
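65 | 
66 | # Hypothetical call sketch (CIFAR-10: ResNet-20 with basic blocks + SimAM):
67 | # from types import SimpleNamespace
68 | # args = SimpleNamespace(arch="resnet20", block_type="basic", num_class=10,
69 | #                        attention_type="simam", attention_param=0.1)
70 | # net = create_net(args)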
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import fnmatch
4 | import torch
5 |
6 |
7 | class AverageMeter(object):
8 | """Computes and stores the average and current value"""
9 | def __init__(self, name, fmt=':f'):
10 | self.name = name
11 | self.fmt = fmt
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
26 | def __str__(self):
27 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
28 | return fmtstr.format(**self.__dict__)
29 |
30 |
31 | class ProgressMeter(object):
32 | def __init__(self, num_batches, meters, prefix=""):
33 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
34 | self.meters = meters
35 | self.prefix = prefix
36 |
37 | def get_message(self, batch):
38 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
39 | entries += [str(meter) for meter in self.meters]
40 | return ('\t').join(entries)
41 |
42 | def _get_batch_fmtstr(self, num_batches):
43 |         num_digits = len(str(num_batches))
44 | fmt = "{:" + str(num_digits) + "d}"
45 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
46 |
47 |
48 | def accuracy(output, target, topk=(1,)):
49 | """Computes the accuracy over the k top predictions for the specified values of k"""
50 | with torch.no_grad():
51 | maxk = max(topk)
52 | batch_size = target.size(0)
53 |
54 | _, pred = output.topk(maxk, 1, True, True)
55 | pred = pred.t()
56 | correct = pred.eq(target.view(1, -1).expand_as(pred))
57 |
58 | res = []
59 | for k in topk:
60 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
61 | res.append(correct_k.mul_(100.0 / batch_size))
62 | return res
63 |
64 |
65 | def parse_gpus(gpu_ids):
66 | gpus = gpu_ids.split(',')
67 | gpu_ids = []
68 | for g in gpus:
69 | g_int = int(g)
70 | if g_int >= 0:
71 | gpu_ids.append(g_int)
72 | if not gpu_ids:
73 | return None
74 | return gpu_ids
75 |
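76 | # Example (hypothetical): top-1/top-5 accuracy for a batch of logits.
77 | # logits = torch.randn(8, 100); labels = torch.randint(0, 100, (8,))
78 | # top1, top5 = accuracy(logits, labels, topk=(1, 5))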
--------------------------------------------------------------------------------
/networks/cifar/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | '''MobileNetV2 in PyTorch.
2 |
3 | See the paper "Inverted Residuals and Linear Bottlenecks:
4 | Mobile Networks for Classification, Detection and Segmentation" for more details.
5 | '''
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from .block import InvertedResidualBlock
10 |
11 | class MobileNetV2(nn.Module):
12 | # (expansion, out_planes, num_blocks, stride)
13 | cfg = [(1, 16, 1, 1),
14 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10
15 | (6, 32, 3, 2),
16 | (6, 64, 4, 2),
17 | (6, 96, 3, 1),
18 | (6, 160, 3, 2),
19 | (6, 320, 1, 1)]
20 |
21 | def __init__(self, block, num_blocks=0, num_class=10):
22 | super(MobileNetV2, self).__init__()
23 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10
24 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
25 | self.bn1 = nn.BatchNorm2d(32)
26 | self.layers = self._make_layers(block, in_planes=32)
27 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
28 | self.bn2 = nn.BatchNorm2d(1280)
29 | self.linear = nn.Linear(1280, num_class)
30 |
31 | def _make_layers(self, block, in_planes):
32 | layers = []
33 | for expansion, out_planes, num_blocks, stride in self.cfg:
34 | strides = [stride] + [1]*(num_blocks-1)
35 | for stride in strides:
36 | layers.append(block(in_planes, out_planes, expansion, stride))
37 | in_planes = out_planes
38 | return nn.Sequential(*layers)
39 |
40 | def forward(self, x):
41 | out = F.relu(self.bn1(self.conv1(x)))
42 | out = self.layers(out)
43 | out = F.relu(self.bn2(self.conv2(out)))
44 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
45 | out = F.avg_pool2d(out, 4)
46 | out = out.view(out.size(0), -1)
47 | out = self.linear(out)
48 | return out
49 |
50 |
51 | def MobileNetV2Wrapper(num_class=10, block=None, attention_module=None):
52 |
53 | b = lambda in_planes, out_planes, expansion, stride: \
54 | InvertedResidualBlock(in_planes, out_planes, expansion, stride, attention_module=attention_module)
55 |
56 | return MobileNetV2(b, num_blocks=0, num_class=num_class)
57 |
58 |
--------------------------------------------------------------------------------
/networks/cifar/wideresnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import functools
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from .block import WideBasicBlock
6 |
7 |
8 | class NetworkBlock(nn.Module):
9 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
10 | super(NetworkBlock, self).__init__()
11 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)
12 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
13 | layers = []
14 | for i in range(nb_layers):
15 |             layers.append(block(in_planes if i == 0 else out_planes, out_planes, stride if i == 0 else 1, dropRate))
16 | return nn.Sequential(*layers)
17 | def forward(self, x):
18 | return self.layer(x)
19 |
20 | class WideResNet(nn.Module):
21 | def __init__(self, block, depth, widen_factor=1, dropRate=0.0, num_class=10):
22 | super(WideResNet, self).__init__()
23 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
24 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
25 | n = (depth - 4) // 6
26 | # block = WideBasicBlock
27 | # 1st conv before any network block
28 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)
29 | # 1st block
30 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
31 | # 2nd block
32 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
33 | # 3rd block
34 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
35 | # global average pooling and classifier
36 | self.bn1 = nn.BatchNorm2d(nChannels[3])
37 | self.relu = nn.ReLU(inplace=True)
38 | self.fc = nn.Linear(nChannels[3], num_class)
39 | self.fc.bias.data.zero_()
40 | # print(self.fc.bias.data)
41 | self.nChannels = nChannels[3]
42 |
43 | for m in self.modules():
44 | if isinstance(m, nn.Conv2d):
45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
46 | m.weight.data.normal_(0, math.sqrt(2. / n))
47 | elif isinstance(m, nn.BatchNorm2d):
48 | m.weight.data.fill_(1)
49 | m.bias.data.zero_()
50 | # elif isinstance(m, nn.Linear):
51 | # if hasattr(m, "bias"):
52 | # m.bias.data.zero_()
53 |
54 | def forward(self, x):
55 | out = self.conv1(x)
56 | out = self.block1(out)
57 | out = self.block2(out)
58 | out = self.block3(out)
59 | out = self.relu(self.bn1(out))
60 | out = F.avg_pool2d(out, 8)
61 | out = out.view(-1, self.nChannels)
62 | return self.fc(out)
63 |
64 |
65 | def WideResNetWrapper(depth, widen_factor, dropRate=0, num_class=10, attention_module=None):
66 |
67 | b = lambda in_planes, planes, stride, dropRate: \
68 | WideBasicBlock(in_planes, planes, stride, dropRate, attention_module=attention_module)
69 |
70 | return WideResNet(b, depth, widen_factor, dropRate, num_class=num_class)
71 |
72 |
73 | def WideResNet28x10(num_class=10, block=None, attention_module=None):
74 |
75 | return WideResNetWrapper(
76 | depth = 28,
77 | widen_factor = 10,
78 | dropRate = 0.3,
79 | num_class = num_class,
80 | attention_module = attention_module)
81 |
82 |
83 | def WideResNet40x10(num_class=10, block=None, attention_module=None):
84 |
85 | return WideResNetWrapper(
86 | depth = 40,
87 | widen_factor = 10,
88 | dropRate = 0.3,
89 | num_class = num_class,
90 | attention_module = attention_module)
--------------------------------------------------------------------------------
/networks/cifar/resnet.py:
--------------------------------------------------------------------------------
1 | """PyTorch implementation of ResNet
2 |
3 | ResNet modifications written by Bichen Wu and Alvin Wan, based
4 | on the ResNet implementation by Kuang Liu.
5 |
6 | Reference:
7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
8 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
9 | """
10 | import functools
11 | import torch.nn as nn
12 | from .block import BasicBlock, BottleNect
13 |
14 | class ResNet(nn.Module):
15 | def __init__(self, block, num_blocks, num_class=10):
16 | super(ResNet, self).__init__()
17 |
18 | self.num_class = num_class
19 | self.in_channels = num_filters = 16
20 |
21 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
22 | self.bn1 = nn.BatchNorm2d(self.in_channels)
23 |
24 | self.relu = nn.ReLU(inplace=True)
25 | self.avgpool = nn.AdaptiveAvgPool2d(1)
26 |
27 | self.layer1 = self._make_layer(block, self.in_channels, num_blocks[0], stride=1)
28 | self.layer2 = self._make_layer(block, int(num_filters*2), num_blocks[1], stride=2)
29 | self.layer3 = self._make_layer(block, int(num_filters*4), num_blocks[2], stride=2)
30 | self.linear = nn.Linear(int(num_filters*4*block(16,16,1).EXPANSION), num_class)
31 |
32 | def _make_layer(self, block, ou_channels, num_blocks, stride):
33 | strides = [stride] + [1]*(num_blocks-1)
34 | layers = []
35 | for stride in strides:
36 | layers.append(block(self.in_channels, ou_channels, stride))
37 | self.in_channels = int(ou_channels * block(16,16,1).EXPANSION)
38 | return nn.Sequential(*layers)
39 |
40 | def forward(self, x):
41 | out = self.relu(self.bn1(self.conv1(x)))
42 | out = self.layer1(out)
43 | out = self.layer2(out)
44 | out = self.layer3(out)
45 | out = self.avgpool(out)
46 | out = out.view(out.size(0), -1)
47 | out = self.linear(out)
48 | return out
49 |
50 |
51 | def ResNetWrapper(num_blocks, num_class=10, block=None, attention_module=None):
52 |
53 | b = lambda in_planes, planes, stride: \
54 | block(in_planes, planes, stride, attention_module=attention_module)
55 |
56 | return ResNet(b, num_blocks, num_class=num_class)
57 |
58 |
59 | def ResNet20(num_class=10, block=None, attention_module=None):
60 | return ResNetWrapper(
61 | num_blocks = [3, 3, 3],
62 | num_class = num_class,
63 | block = block,
64 | attention_module = attention_module)
65 |
66 | def ResNet32(num_class=10, block=None, attention_module=None):
67 | return ResNetWrapper(
68 | num_blocks = [5, 5, 5],
69 | num_class = num_class,
70 | block = block,
71 | attention_module = attention_module)
72 |
73 |
74 | def ResNet56(num_class=10, block=None, attention_module=None):
75 |
76 | if block == BasicBlock:
77 | n_blocks = [9, 9, 9]
78 | elif block == BottleNect:
79 | n_blocks = [6, 6, 6]
80 |
81 | return ResNetWrapper(
82 | num_blocks = n_blocks,
83 | num_class = num_class,
84 | block = block,
85 | attention_module = attention_module)
86 |
87 | def ResNet110(num_class=10, block=None, attention_module=None):
88 |
89 | if block == BasicBlock:
90 | n_blocks = [18, 18, 18]
91 | elif block == BottleNect:
92 | n_blocks = [12, 12, 12]
93 |
94 | return ResNetWrapper(
95 | num_blocks = n_blocks,
96 | num_class = num_class,
97 | block = block,
98 | attention_module = attention_module)
99 |
100 |
101 | def ResNet164(num_class=10, block=None, attention_module=None):
102 |
103 | if block == BasicBlock:
104 | n_blocks = [27, 27, 27]
105 | elif block == BottleNect:
106 | n_blocks = [18, 18, 18]
107 |
108 | return ResNetWrapper(
109 | num_blocks = n_blocks,
110 | num_class = num_class,
111 | block = block,
112 | attention_module = attention_module)
--------------------------------------------------------------------------------
/mmdetection/configs/_base_/models/faster_rcnn_r50simam_fpn.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='FasterRCNN',
3 | pretrained='checkpoints/simam-net/resnet50.pth.tar',
4 | backbone=dict(
5 | type='ResNetAM',
6 | depth=50,
7 | frozen_stages=1,
8 | norm_eval=True,
9 | attention_type="simam",
10 | attention_param=0.1),
11 | neck=dict(
12 | type='FPN',
13 | in_channels=[256, 512, 1024, 2048],
14 | out_channels=256,
15 | num_outs=5),
16 | rpn_head=dict(
17 | type='RPNHead',
18 | in_channels=256,
19 | feat_channels=256,
20 | anchor_generator=dict(
21 | type='AnchorGenerator',
22 | scales=[8],
23 | ratios=[0.5, 1.0, 2.0],
24 | strides=[4, 8, 16, 32, 64]),
25 | bbox_coder=dict(
26 | type='DeltaXYWHBBoxCoder',
27 | target_means=[.0, .0, .0, .0],
28 | target_stds=[1.0, 1.0, 1.0, 1.0]),
29 | loss_cls=dict(
30 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
31 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
32 | roi_head=dict(
33 | type='StandardRoIHead',
34 | bbox_roi_extractor=dict(
35 | type='SingleRoIExtractor',
36 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
37 | out_channels=256,
38 | featmap_strides=[4, 8, 16, 32]),
39 | bbox_head=dict(
40 | type='Shared2FCBBoxHead',
41 | in_channels=256,
42 | fc_out_channels=1024,
43 | roi_feat_size=7,
44 | num_classes=80,
45 | bbox_coder=dict(
46 | type='DeltaXYWHBBoxCoder',
47 | target_means=[0., 0., 0., 0.],
48 | target_stds=[0.1, 0.1, 0.2, 0.2]),
49 | reg_class_agnostic=False,
50 | loss_cls=dict(
51 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
52 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
53 | # model training and testing settings
54 | train_cfg=dict(
55 | rpn=dict(
56 | assigner=dict(
57 | type='MaxIoUAssigner',
58 | pos_iou_thr=0.7,
59 | neg_iou_thr=0.3,
60 | min_pos_iou=0.3,
61 | match_low_quality=True,
62 | ignore_iof_thr=-1),
63 | sampler=dict(
64 | type='RandomSampler',
65 | num=256,
66 | pos_fraction=0.5,
67 | neg_pos_ub=-1,
68 | add_gt_as_proposals=False),
69 | allowed_border=-1,
70 | pos_weight=-1,
71 | debug=False),
72 | rpn_proposal=dict(
73 | nms_across_levels=False,
74 | nms_pre=2000,
75 | nms_post=1000,
76 | max_num=1000,
77 | nms_thr=0.7,
78 | min_bbox_size=0),
79 | rcnn=dict(
80 | assigner=dict(
81 | type='MaxIoUAssigner',
82 | pos_iou_thr=0.5,
83 | neg_iou_thr=0.5,
84 | min_pos_iou=0.5,
85 | match_low_quality=False,
86 | ignore_iof_thr=-1),
87 | sampler=dict(
88 | type='RandomSampler',
89 | num=512,
90 | pos_fraction=0.25,
91 | neg_pos_ub=-1,
92 | add_gt_as_proposals=True),
93 | pos_weight=-1,
94 | debug=False)),
95 | test_cfg=dict(
96 | rpn=dict(
97 | nms_across_levels=False,
98 | nms_pre=1000,
99 | nms_post=1000,
100 | max_num=1000,
101 | nms_thr=0.7,
102 | min_bbox_size=0),
103 | rcnn=dict(
104 | score_thr=0.05,
105 | nms=dict(type='nms', iou_threshold=0.5),
106 | max_per_img=100)
107 | # soft-nms is also supported for rcnn testing
108 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
109 | ))
110 |
--------------------------------------------------------------------------------
/networks/cifar/preresnet.py:
--------------------------------------------------------------------------------
1 | """PyTorch implementation of ResNet
2 |
3 | ResNet modifications written by Bichen Wu and Alvin Wan, based
4 | off of ResNet implementation by Kuang Liu.
5 |
6 | Reference:
7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
8 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
9 | """
10 | import functools
11 | import torch.nn as nn
12 | from .block import PreBasicBlock, PreBottleNect
13 |
14 | class PreResNet(nn.Module):
15 | def __init__(self, block, num_blocks, num_class=10):
16 | super(PreResNet, self).__init__()
17 |
18 | self.num_class = num_class
19 | self.in_channels = num_base_filters = 16
20 |
21 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
22 |
23 | self.layer1 = self._make_layer(block, self.in_channels, num_blocks[0], stride=1)
24 | self.layer2 = self._make_layer(block, int(num_base_filters*2), num_blocks[1], stride=2)
25 | self.layer3 = self._make_layer(block, int(num_base_filters*4), num_blocks[2], stride=2)
26 |
27 | self.bn = nn.BatchNorm2d(int(num_base_filters*4*block(16,16,1).EXPANSION))
28 |
29 | self.linear = nn.Linear(int(num_base_filters*4*block(16,16,1).EXPANSION), num_class)
30 |
31 | self.relu = nn.ReLU(inplace=True)
32 |
33 | self.avgpool = nn.AdaptiveAvgPool2d(1)
34 |
35 | def _make_layer(self, block, ou_channels, num_blocks, stride):
36 | strides = [stride] + [1]*(num_blocks-1)
37 | layers = []
38 | for stride in strides:
39 | layers.append(block(self.in_channels, ou_channels, stride))
40 | self.in_channels = int(ou_channels * block(16,16,1).EXPANSION)
41 | return nn.Sequential(*layers)
42 |
43 | def forward(self, x):
44 | out = self.conv1(x)
45 | out = self.layer1(out)
46 | out = self.layer2(out)
47 | out = self.layer3(out)
48 | out = self.bn(out)
49 | out = self.relu(out)
50 |
51 | out = self.avgpool(out)
52 | out = out.view(out.size(0), -1)
53 | out = self.linear(out)
54 | return out
55 |
56 |
57 | def PreResNetWrapper(num_blocks, num_class=10, block=None, attention_module=None):
58 |
59 | b = lambda in_planes, planes, stride: \
60 | block(in_planes, planes, stride, attention_module=attention_module)
61 |
62 | return PreResNet(b, num_blocks, num_class=num_class)
63 |
64 |
65 | def PreResNet20(num_class=10, block=None, attention_module=None):
66 | return PreResNetWrapper(
67 | num_blocks = [3, 3, 3],
68 | num_class = num_class,
69 | block = block,
70 | attention_module = attention_module)
71 |
72 |
73 | def PreResNet32(num_class=10, block=None, attention_module=None):
74 | return PreResNetWrapper(
75 | num_blocks = [5, 5, 5],
76 | num_class = num_class,
77 | block = block,
78 | attention_module = attention_module)
79 |
80 |
81 | def PreResNet56(num_class=10, block=None, attention_module=None):
82 |
83 | if block == PreBasicBlock:
84 | n_blocks = [9, 9, 9]
85 | elif block == PreBottleNect:
86 | n_blocks = [6, 6, 6]
87 |
88 | return PreResNetWrapper(
89 | num_blocks = n_blocks,
90 | num_class = num_class,
91 | block = block,
92 | attention_module = attention_module)
93 |
94 | def PreResNet110(num_class=10, block=None, attention_module=None):
95 |
96 | if block == PreBasicBlock:
97 | n_blocks = [18, 18, 18]
98 | elif block == PreBottleNect:
99 | n_blocks = [12, 12, 12]
100 |
101 | return PreResNetWrapper(
102 | num_blocks = n_blocks,
103 | num_class = num_class,
104 | block = block,
105 | attention_module = attention_module)
106 |
107 |
108 | def PreResNet164(num_class=10, block=None, attention_module=None):
109 |
110 | if block == PreBasicBlock:
111 | n_blocks = [27, 27, 27]
112 | elif block == PreBottleNect:
113 | n_blocks = [18, 18, 18]
114 |
115 | return PreResNetWrapper(
116 | num_blocks = n_blocks,
117 | num_class = num_class,
118 | block = block,
119 | attention_module = attention_module)
--------------------------------------------------------------------------------
/networks/attentions/cbam_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class BasicConv(nn.Module):
7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
8 | super(BasicConv, self).__init__()
9 | self.out_channels = out_planes
10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
12 | self.relu = nn.ReLU() if relu else None
13 |
14 | def forward(self, x):
15 | x = self.conv(x)
16 | if self.bn is not None:
17 | x = self.bn(x)
18 | if self.relu is not None:
19 | x = self.relu(x)
20 | return x
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, x):
24 | return x.view(x.size(0), -1)
25 |
26 | class ChannelGate(nn.Module):
27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
28 | super(ChannelGate, self).__init__()
29 | self.gate_channels = gate_channels
30 | self.mlp = nn.Sequential(
31 | Flatten(),
32 | nn.Linear(gate_channels, int(gate_channels // reduction_ratio)),
33 | nn.ReLU(),
34 | nn.Linear(int(gate_channels // reduction_ratio), gate_channels)
35 | )
36 | self.pool_types = pool_types
37 | def forward(self, x):
38 | channel_att_sum = None
39 | for pool_type in self.pool_types:
40 | if pool_type=='avg':
41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
42 | channel_att_raw = self.mlp( avg_pool )
43 | elif pool_type=='max':
44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
45 | channel_att_raw = self.mlp( max_pool )
46 | elif pool_type=='lp':
47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
48 | channel_att_raw = self.mlp( lp_pool )
49 | elif pool_type=='lse':
50 | # LSE pool only
51 | lse_pool = logsumexp_2d(x)
52 | channel_att_raw = self.mlp( lse_pool )
53 |
54 | if channel_att_sum is None:
55 | channel_att_sum = channel_att_raw
56 | else:
57 | channel_att_sum = channel_att_sum + channel_att_raw
58 |
59 |         scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
60 | return x * scale
61 |
62 | def logsumexp_2d(tensor):
63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
66 | return outputs
67 |
68 | class ChannelPool(nn.Module):
69 | def forward(self, x):
70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
71 |
72 | class SpatialGate(nn.Module):
73 | def __init__(self):
74 | super(SpatialGate, self).__init__()
75 | kernel_size = 7
76 | self.compress = ChannelPool()
77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size-1) // 2), relu=False)
78 | def forward(self, x):
79 | x_compress = self.compress(x)
80 | x_out = self.spatial(x_compress)
81 |         scale = torch.sigmoid(x_out)  # broadcast over channels
82 | return x * scale
83 |
84 | class cbam_module(nn.Module):
85 | def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False):
86 | super(cbam_module, self).__init__()
87 | self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types)
88 | self.no_spatial=no_spatial
89 | if not no_spatial:
90 | self.SpatialGate = SpatialGate()
91 |
92 | @staticmethod
93 | def get_module_name():
94 | return "cbam"
95 |
96 | def forward(self, x):
97 | x_out = self.ChannelGate(x)
98 | if not self.no_spatial:
99 | x_out = self.SpatialGate(x_out)
100 | return x_out
--------------------------------------------------------------------------------
/mmdetection/configs/_base_/models/mask_rcnn_r50simam_fpn.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 | type='MaskRCNN',
4 | pretrained='checkpoints/simam-net/resnet50.pth.tar',
5 | backbone=dict(
6 | type='ResNetAM',
7 | depth=50,
8 | frozen_stages=1,
9 | norm_eval=True,
10 | attention_type="simam",
11 | attention_param=0.1),
12 | neck=dict(
13 | type='FPN',
14 | in_channels=[256, 512, 1024, 2048],
15 | out_channels=256,
16 | num_outs=5),
17 | rpn_head=dict(
18 | type='RPNHead',
19 | in_channels=256,
20 | feat_channels=256,
21 | anchor_generator=dict(
22 | type='AnchorGenerator',
23 | scales=[8],
24 | ratios=[0.5, 1.0, 2.0],
25 | strides=[4, 8, 16, 32, 64]),
26 | bbox_coder=dict(
27 | type='DeltaXYWHBBoxCoder',
28 | target_means=[.0, .0, .0, .0],
29 | target_stds=[1.0, 1.0, 1.0, 1.0]),
30 | loss_cls=dict(
31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
33 | roi_head=dict(
34 | type='StandardRoIHead',
35 | bbox_roi_extractor=dict(
36 | type='SingleRoIExtractor',
37 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
38 | out_channels=256,
39 | featmap_strides=[4, 8, 16, 32]),
40 | bbox_head=dict(
41 | type='Shared2FCBBoxHead',
42 | in_channels=256,
43 | fc_out_channels=1024,
44 | roi_feat_size=7,
45 | num_classes=80,
46 | bbox_coder=dict(
47 | type='DeltaXYWHBBoxCoder',
48 | target_means=[0., 0., 0., 0.],
49 | target_stds=[0.1, 0.1, 0.2, 0.2]),
50 | reg_class_agnostic=False,
51 | loss_cls=dict(
52 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
53 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
54 | mask_roi_extractor=dict(
55 | type='SingleRoIExtractor',
56 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
57 | out_channels=256,
58 | featmap_strides=[4, 8, 16, 32]),
59 | mask_head=dict(
60 | type='FCNMaskHead',
61 | num_convs=4,
62 | in_channels=256,
63 | conv_out_channels=256,
64 | num_classes=80,
65 | loss_mask=dict(
66 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
67 | # model training and testing settings
68 | train_cfg=dict(
69 | rpn=dict(
70 | assigner=dict(
71 | type='MaxIoUAssigner',
72 | pos_iou_thr=0.7,
73 | neg_iou_thr=0.3,
74 | min_pos_iou=0.3,
75 | match_low_quality=True,
76 | ignore_iof_thr=-1),
77 | sampler=dict(
78 | type='RandomSampler',
79 | num=256,
80 | pos_fraction=0.5,
81 | neg_pos_ub=-1,
82 | add_gt_as_proposals=False),
83 | allowed_border=-1,
84 | pos_weight=-1,
85 | debug=False),
86 | rpn_proposal=dict(
87 | nms_across_levels=False,
88 | nms_pre=2000,
89 | nms_post=1000,
90 | max_num=1000,
91 | nms_thr=0.7,
92 | min_bbox_size=0),
93 | rcnn=dict(
94 | assigner=dict(
95 | type='MaxIoUAssigner',
96 | pos_iou_thr=0.5,
97 | neg_iou_thr=0.5,
98 | min_pos_iou=0.5,
99 | match_low_quality=True,
100 | ignore_iof_thr=-1),
101 | sampler=dict(
102 | type='RandomSampler',
103 | num=512,
104 | pos_fraction=0.25,
105 | neg_pos_ub=-1,
106 | add_gt_as_proposals=True),
107 | mask_size=28,
108 | pos_weight=-1,
109 | debug=False)),
110 | test_cfg=dict(
111 | rpn=dict(
112 | nms_across_levels=False,
113 | nms_pre=1000,
114 | nms_post=1000,
115 | max_num=1000,
116 | nms_thr=0.7,
117 | min_bbox_size=0),
118 | rcnn=dict(
119 | score_thr=0.05,
120 | nms=dict(type='nms', iou_threshold=0.5),
121 | max_per_img=100,
122 | mask_thr_binary=0.5)))
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
2 | [**Lingxiao Yang**](https://zjjconan.github.io/), [Ru-Yuan Zhang](https://ruyuanzhang.github.io/), [Lida Li](https://github.com/lld533), [Xiaohua Xie](http://cse.sysu.edu.cn/content/2478)
3 |
4 | Abstract
5 | ----------
6 | In this paper, we propose a conceptually simple but very effective attention module for Convolutional Neural Networks (ConvNets). In contrast to existing channel-wise and spatial-wise attention modules, our module infers 3-D attention weights for the feature map in a layer without adding parameters to the original networks. Specifically, building on well-known neuroscience theories, we propose to optimize an energy function to find the importance of each neuron. We further derive a fast closed-form solution for the energy function, and show that the solution can be implemented in less than ten lines of code. Another advantage of the module is that most of the operators are selected based on the solution to the defined energy function, avoiding excessive effort on structure tuning. Quantitative evaluations on various visual tasks demonstrate that the proposed module is flexible and effective at improving the representation ability of many ConvNets. Our code is available at [Pytorch-SimAM](https://github.com/ZjjConan/SimAM).
7 |
8 | --------------------------------------------------
9 |
10 | Our environments and toolkits
11 | -----------
12 |
13 | - OS: Ubuntu 18.04.5
14 | - CUDA: 11.0
15 | - Python: 3.8.3
16 | - Toolkit: PyTorch 1.8.0
17 | - GPU: Quadro RTX 8000 (4x)
18 | - [thop](https://github.com/Lyken17/pytorch-OpCounter)
19 |
20 |
21 | Module
22 | ------
23 |
24 | Our goal is to infer 3-D attention weights (Figure (c)) for a given feature map, which is very different from previous works shown in Figures (a) and (b).
25 | 
26 | ![attentions](figures/attentions.png)
27 | 
28 | 
29 |
30 | **SimAM (a PyTorch-like implementation).** Details of the implementation, including the module and the networks, can be found in the ``networks`` folder of this repository.
31 |
32 |
33 | ```python
34 | class SimAM(nn.Module):
35 |     # X: input feature [N, C, H, W]
36 |     # e_lambda: coefficient λ in Eqn (5)
37 |     def forward(self, X, e_lambda=1e-4):
38 |         # spatial size
39 |         n = X.shape[2] * X.shape[3] - 1
40 |         # square of (t - u); keepdim so the stats broadcast over [N, C, H, W]
41 |         d = (X - X.mean(dim=[2,3], keepdim=True)).pow(2)
42 |         # d.sum() / n is channel variance
43 |         v = d.sum(dim=[2,3], keepdim=True) / n
44 |         # E_inv groups all importance of X
45 |         E_inv = d / (4 * (v + e_lambda)) + 0.5
46 |         # return attended features
47 |         return X * torch.sigmoid(E_inv)
48 | ```
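As a quick sanity check, the module above can be applied to a random tensor (a minimal sketch; assumes the usual `torch` imports and the `SimAM` class from the snippet above):

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 32, 32)   # [N, C, H, W]
out = SimAM()(x)                 # parameter-free: the module holds no weights
assert out.shape == x.shape      # attention only rescales the features
```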
49 |
50 | Training and Validation Curves
51 | ----------
52 |
53 | 
54 | ![training_curves](figures/training_curves.png)
55 | 
56 |
57 | Experiments
58 | ----------
59 |
60 | ### Training and Evaluation
61 |
62 | The following commands train models on ImageNet from scratch with 4 GPUs.
63 |
64 |
65 | ```
66 | # Training from scratch
67 |
68 | python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 100 -j 20 -a resnet18
69 |
70 | python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 100 -j 20 -a resnet18 \
71 |        --attention_type simam --attention_param 0.1
72 | 
73 | python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 150 -j 20 -a mobilenet_v2 \
74 |        --attention_type simam --attention_param 0.1 --lr .05 --cos_lr --wd 4e-5
75 | ```
76 |
77 | ```
78 | # Evaluating the trained model
79 |
80 | python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 -j 20 -a resnet18 -e \
81 |        --resume {the path of pretrained .pth}
82 | ```
83 |
84 | ### ImageNet
85 |
86 | All the following models can be downloaded from **[BaiduYunPan](https://pan.baidu.com/s/1i_imf4Ny4U9SkD5e9nhCJw)** (extract code: **25tp**) and **[Google Drive](https://drive.google.com/drive/folders/1rRT0UCPeRLPdTCJvv43hvAnGnS49nIWn?usp=sharing).**
87 |
88 | |Model |Parameters |FLOPs |Top-1(%) |Top-5(%)|
89 | |:---: |:----: |:---: |:------: |:------:|
90 | |SimAM-R18 |11.69 M |1.82 G |71.31 |89.88 |
91 | |SimAM-R34 |21.80 M |3.67 G |74.49 |92.02 |
92 | |SimAM-R50 |25.56 M |4.11 G |77.45 |93.66 |
93 | |SimAM-R101 |44.55 M |7.83 G |78.65 |94.11 |
94 | |SimAM-RX50 (32x4d) |25.03 M |4.26 G |78.00 |93.93 |
95 | |SimAM-MV2 |3.50 M |0.31 G |72.36 |90.74 |
96 |
97 | ### COCO Evaluation
98 | We use [mmdetection](https://github.com/open-mmlab/mmdetection) to train Faster RCNN and Mask RCNN for object detection and instance segmentation. To run the following models, please first install `mmdetection` following its guide, and then copy all `.py` files under the `mmdetection` folder of this repository into the corresponding folders of your `mmdetection` installation (see the sketch below). All the following models can be downloaded from **[BaiduYunPan](https://pan.baidu.com/s/1NtMgu09vv0tEhb2PsXCsQQ)** (extract code: **ysrz**) and **[Google Drive](https://drive.google.com/drive/folders/1F8W3MY32crU6jUeV2sgc_4AQwqt_MvAp?usp=sharing).**
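The placement step might look like the following (a sketch under assumptions: this repository sits at `./SimAM`, mmdetection is cloned at `./mmdetection`, and the `ResNetAM` backbone still has to be registered so the configs can find it):

```
cp -r SimAM/mmdetection/configs/* mmdetection/configs/
cp -r SimAM/mmdetection/mmdet/models/backbones/* mmdetection/mmdet/models/backbones/
# register the backbone (hypothetical edit) in mmdet/models/backbones/__init__.py:
#   from .resnet_simam import ResNetAM
```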
99 |
100 | #### Detection with Faster RCNN (FR for short) and Mask RCNN (MR for short)
101 |
102 | |Model |AP |AP_50 |AP_75|AP_S |AP_M |AP_L |
103 | |:----: |:----: |:---: |:--: |:----: |:---: |:--: |
104 | |FR-SimAM-R50 |39.2 |60.7 |40.8 |22.8 |43.0 |50.6 |
105 | |FR-SimAM-R101 |41.2 |62.4 |45.0 |24.0 |45.6 |52.8 |
106 | |MR-SimAM-R50 |39.8 |61.0 |43.4 |23.1 |43.7 |51.4 |
107 | |MR-SimAM-R101 |41.8 |62.8 |46.0 |24.8 |46.2 |53.9 |
108 |
109 |
110 | #### Instance Segmentation with Mask RCNN (MR for short)
111 |
112 | |Model |AP |AP_50 |AP_75|AP_S |AP_M |AP_L |
113 | |:----: |:----: |:---: |:--: |:----: |:---: |:--: |
114 | |MR-SimAM-R50 |36.0 |57.9 |38.2 |19.1 |39.7 |48.6 |
115 | |MR-SimAM-R101 |37.6 |59.5 |40.1 |20.5 |41.5 |50.8 |
116 |
117 | --------------------------------------------------------------------
118 |
119 |
120 | Citation
121 | --------
122 | If you find SimAM useful in your research, please consider citing:
123 |
124 | @InProceedings{pmlr-v139-yang21o,
125 | title = {SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks},
126 | author = {Yang, Lingxiao and Zhang, Ru-Yuan and Li, Lida and Xie, Xiaohua},
127 | booktitle = {Proceedings of the 38th International Conference on Machine Learning},
128 | pages = {11863--11874},
129 | year = {2021},
130 | editor = {Meila, Marina and Zhang, Tong},
131 | volume = {139},
132 | series = {Proceedings of Machine Learning Research},
133 | month = {18--24 Jul},
134 | publisher = {PMLR},
135 | pdf = {http://proceedings.mlr.press/v139/yang21o/yang21o.pdf},
136 | url = {http://proceedings.mlr.press/v139/yang21o.html}
137 | }
138 |
139 | ## Contact Information
140 |
141 | If you have any suggestions or questions, you can contact us at lingxiao.yang717@gmail.com. Thanks for your attention!
142 |
--------------------------------------------------------------------------------
/networks/imagenet/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch import nn
3 | from torch import Tensor
4 | # from .utils import load_state_dict_from_url
5 | from typing import Callable, Any, Optional, List
6 |
7 |
8 | __all__ = ['MobileNetV2', 'mobilenet_v2']
9 |
10 |
11 | # model_urls = {
12 | # 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
13 | # }
14 |
15 |
16 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
17 | """
18 | This function is taken from the original tf repo.
19 | It ensures that all layers have a channel number that is divisible by 8
20 | It can be seen here:
21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22 |     :param v: raw channel count (possibly fractional after width scaling)
23 |     :param divisor: the required divisor (8 for MobileNetV2)
24 |     :param min_value: lower bound on the result; defaults to `divisor`
25 |     :return: `v` rounded to a multiple of `divisor`, e.g. _make_divisible(37.5, 8) -> 40
26 | """
27 | if min_value is None:
28 | min_value = divisor
29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30 | # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v:
32 | new_v += divisor
33 | return new_v
34 |
35 |
36 | class ConvBNActivation(nn.Sequential):
37 | def __init__(
38 | self,
39 | in_planes: int,
40 | out_planes: int,
41 | kernel_size: int = 3,
42 | stride: int = 1,
43 | groups: int = 1,
44 | norm_layer: Optional[Callable[..., nn.Module]] = None,
45 | activation_layer: Optional[Callable[..., nn.Module]] = None,
46 | attention_module: Optional[Callable[..., nn.Module]] = None,
47 | ) -> None:
48 | padding = (kernel_size - 1) // 2
49 | if norm_layer is None:
50 | norm_layer = nn.BatchNorm2d
51 | if activation_layer is None:
52 | activation_layer = nn.ReLU6
53 | if attention_module is not None:
54 | if type(attention_module) == functools.partial:
55 | module_name = attention_module.func.get_module_name()
56 | else:
57 | module_name = attention_module.get_module_name()
58 |
59 | if module_name == "simam":
60 |                 super(ConvBNActivation, self).__init__(
61 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
62 | attention_module(out_planes),
63 | norm_layer(out_planes),
64 | activation_layer(inplace=True)
65 | )
66 | else:
67 |                 super(ConvBNActivation, self).__init__(
68 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
69 | norm_layer(out_planes),
70 | activation_layer(inplace=True)
71 | )
72 | else:
73 |             super(ConvBNActivation, self).__init__(
74 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
75 | norm_layer(out_planes),
76 | activation_layer(inplace=True)
77 | )
78 |
79 |
80 | # necessary for backwards compatibility
81 | ConvBNReLU = ConvBNActivation
82 |
83 |
84 | class InvertedResidual(nn.Module):
85 | def __init__(
86 | self,
87 | inp: int,
88 | oup: int,
89 | stride: int,
90 | expand_ratio: int,
91 | norm_layer: Optional[Callable[..., nn.Module]] = None,
92 | attention_module: Optional[Callable[..., nn.Module]] = None
93 | ) -> None:
94 | super(InvertedResidual, self).__init__()
95 | self.stride = stride
96 | assert stride in [1, 2]
97 |
98 | if norm_layer is None:
99 | norm_layer = nn.BatchNorm2d
100 |
101 | hidden_dim = int(round(inp * expand_ratio))
102 | self.use_res_connect = self.stride == 1 and inp == oup
103 |
104 | layers: List[nn.Module] = []
105 | if expand_ratio != 1:
106 | # pw
107 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
108 | layers.extend([
109 | # dw
110 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, attention_module=attention_module),
111 | # pw-linear
112 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
113 | norm_layer(oup),
114 | ])
115 |
116 | if attention_module is not None:
117 | if type(attention_module) == functools.partial:
118 | module_name = attention_module.func.get_module_name()
119 | else:
120 | module_name = attention_module.get_module_name()
121 |
122 | if module_name != "simam":
123 | # print(attention_module)
124 | layers.append(attention_module(oup))
125 |
126 | self.conv = nn.Sequential(*layers)
127 |
128 | def forward(self, x: Tensor) -> Tensor:
129 | if self.use_res_connect:
130 | return x + self.conv(x)
131 | else:
132 | return self.conv(x)
133 |
134 |
135 | class MobileNetV2(nn.Module):
136 | def __init__(
137 | self,
138 | num_classes: int = 1000,
139 | width_mult: float = 1.0,
140 | inverted_residual_setting: Optional[List[List[int]]] = None,
141 | round_nearest: int = 8,
142 | block: Optional[Callable[..., nn.Module]] = None,
143 | norm_layer: Optional[Callable[..., nn.Module]] = None,
144 | attention_module: Optional[Callable[..., nn.Module]] = None
145 | ) -> None:
146 | """
147 | MobileNet V2 main class
148 |
149 | Args:
150 | num_classes (int): Number of classes
151 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
152 | inverted_residual_setting: Network structure
153 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
154 | Set to 1 to turn off rounding
155 | block: Module specifying inverted residual building block for mobilenet
156 | norm_layer: Module specifying the normalization layer to use
157 |
158 | attention_module: Module specifying the attention layer to use
159 | """
160 | super(MobileNetV2, self).__init__()
161 |
162 | if block is None:
163 | block = InvertedResidual
164 |
165 | if norm_layer is None:
166 | norm_layer = nn.BatchNorm2d
167 |
168 | input_channel = 32
169 | last_channel = 1280
170 |
171 | if inverted_residual_setting is None:
172 | inverted_residual_setting = [
173 | # t, c, n, s
174 | [1, 16, 1, 1],
175 | [6, 24, 2, 2],
176 | [6, 32, 3, 2],
177 | [6, 64, 4, 2],
178 | [6, 96, 3, 1],
179 | [6, 160, 3, 2],
180 | [6, 320, 1, 1],
181 | ]
182 |
183 | # only check the first element, assuming user knows t,c,n,s are required
184 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
185 | raise ValueError("inverted_residual_setting should be a non-empty "
186 | "list of 4-element lists, got {}".format(inverted_residual_setting))
187 |
188 | # building first layer
189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
191 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
192 | # building inverted residual blocks
193 | for t, c, n, s in inverted_residual_setting:
194 | output_channel = _make_divisible(c * width_mult, round_nearest)
195 | for i in range(n):
196 | stride = s if i == 0 else 1
197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, attention_module=attention_module))
198 | input_channel = output_channel
199 | # building last several layers
200 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
201 | # make it nn.Sequential
202 | self.features = nn.Sequential(*features)
203 |
204 | # building classifier
205 | self.classifier = nn.Sequential(
206 | # nn.Dropout(0.2),
207 | nn.Linear(self.last_channel, num_classes),
208 | )
209 |
210 | # weight initialization
211 | for m in self.modules():
212 | if isinstance(m, nn.Conv2d):
213 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
214 | if m.bias is not None:
215 | nn.init.zeros_(m.bias)
216 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
217 | nn.init.ones_(m.weight)
218 | nn.init.zeros_(m.bias)
219 | elif isinstance(m, nn.Linear):
220 | nn.init.normal_(m.weight, 0, 0.01)
221 | if m.bias is not None:
222 | nn.init.zeros_(m.bias)
223 |
224 | def _forward_impl(self, x: Tensor) -> Tensor:
225 | # This exists since TorchScript doesn't support inheritance, so the superclass method
226 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
227 | x = self.features(x)
228 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
229 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
230 | x = self.classifier(x)
231 | return x
232 |
233 | def forward(self, x: Tensor) -> Tensor:
234 | return self._forward_impl(x)
235 |
236 |
237 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
238 | """
239 | Constructs a MobileNetV2 architecture from
240 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_.
241 |
242 | Args:
243 | pretrained (bool): If True, returns a model pre-trained on ImageNet
244 | progress (bool): If True, displays a progress bar of the download to stderr
245 | """
246 | model = MobileNetV2(**kwargs)
247 | # if pretrained:
248 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
249 | # progress=progress)
250 | # model.load_state_dict(state_dict)
251 | return model
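As a usage sketch (the import path for the SimAM class follows this repo's networks/attentions package and is an assumption here; the e_lambda value is illustrative):

```python
import functools
from networks.attentions.simam_module import simam_module  # assumed import path

# Build MobileNetV2 with SimAM attached after each depthwise convolution.
# functools.partial keeps the constructor callable as attention_module(channels).
attention = functools.partial(simam_module, e_lambda=1e-4)
net = mobilenet_v2(num_classes=1000, attention_module=attention)
```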
--------------------------------------------------------------------------------
/networks/cifar/block.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def conv3x3(in_channels, ou_channels, stride=1):
8 | return nn.Conv2d(in_channels, ou_channels, kernel_size=3, stride=stride, padding=1, bias=False)
9 |
10 |
11 | def conv1x1(in_channels, ou_channels, stride=1):
12 | return nn.Conv2d(in_channels, ou_channels, kernel_size=1, stride=stride, padding=0, bias=False)
13 |
14 |
15 | # Basic Block in ResNet for CIFAR
16 | class BasicBlock(nn.Module):
17 |
18 | EXPANSION = 1
19 |
20 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None):
21 | super(BasicBlock, self).__init__()
22 |
23 | self.relu = nn.ReLU(inplace=True)
24 |
25 | self.conv1 = conv3x3(in_channels, ou_channels, stride=stride)
26 | self.bn1 = nn.BatchNorm2d(ou_channels)
27 |
28 | self.conv2 = conv3x3(ou_channels, ou_channels * self.EXPANSION, stride=1)
29 | self.bn2 = nn.BatchNorm2d(ou_channels * self.EXPANSION)
30 |
31 | if attention_module is not None:
32 | if type(attention_module) == functools.partial:
33 | module_name = attention_module.func.get_module_name()
34 | else:
35 | module_name = attention_module.get_module_name()
36 |
37 |
38 | if module_name == "simam":
39 | self.conv2 = nn.Sequential(
40 | self.conv2,
41 | attention_module(ou_channels * self.EXPANSION)
42 | )
43 | else:
44 | self.bn2 = nn.Sequential(
45 | self.bn2,
46 | attention_module(ou_channels * self.EXPANSION)
47 | )
48 |
49 | self.shortcut = nn.Sequential()
50 | if stride != 1 or in_channels != ou_channels * self.EXPANSION:
51 | self.shortcut = nn.Sequential(
52 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride),
53 | nn.BatchNorm2d(ou_channels * self.EXPANSION)
54 | )
55 |
56 | def forward(self, x):
57 | out = self.conv1(x)
58 | out = self.bn1(out)
59 | out = self.relu(out)
60 |
61 | out = self.conv2(out)
62 | out = self.bn2(out)
63 |
64 | out += self.shortcut(x)
65 |
66 | return self.relu(out)
67 |
68 | # Bottleneck in ResNet for CIFAR
69 | class BottleNect(nn.Module):
70 |
71 | EXPANSION = 4
72 |
73 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None):
74 | super(BottleNect, self).__init__()
75 |
76 | self.relu = nn.ReLU(inplace=True)
77 |
78 | self.conv1 = conv1x1(in_channels, ou_channels, stride=1)
79 | self.bn1 = nn.BatchNorm2d(ou_channels)
80 |
81 | self.conv2 = conv3x3(ou_channels, ou_channels, stride=stride)
82 | self.bn2 = nn.BatchNorm2d(ou_channels)
83 |
84 | self.conv3 = conv1x1(ou_channels, ou_channels * self.EXPANSION, stride=1)
85 | self.bn3 = nn.BatchNorm2d(ou_channels * self.EXPANSION)
86 |
87 | if attention_module is not None:
88 | if type(attention_module) == functools.partial:
89 | module_name = attention_module.func.get_module_name()
90 | else:
91 | module_name = attention_module.get_module_name()
92 |
93 | if module_name == "simam":
94 | self.conv2 = nn.Sequential(
95 | self.conv2,
96 | attention_module(ou_channels * self.EXPANSION)  # simam is parameter-free; the channel argument is unused
97 | )
98 | else:
99 | self.bn3 = nn.Sequential(
100 | self.bn3,
101 | attention_module(ou_channels * self.EXPANSION)
102 | )
103 |
104 | self.shortcut = nn.Sequential()
105 | if stride != 1 or in_channels != ou_channels * self.EXPANSION:
106 | self.shortcut = nn.Sequential(
107 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride),
108 | nn.BatchNorm2d(ou_channels * self.EXPANSION)
109 | )
110 |
111 | def forward(self, x):
112 | out = self.conv1(x)
113 | out = self.bn1(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv2(out)
117 | out = self.bn2(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv3(out)
121 | out = self.bn3(out)
122 |
123 | out += self.shortcut(x)
124 |
125 | return self.relu(out)
126 |
127 |
128 |
129 | # Pre-activation Basic Block in ResNet for CIFAR
130 | class PreBasicBlock(nn.Module):
131 |
132 | EXPANSION = 1
133 |
134 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None):
135 | super(PreBasicBlock, self).__init__()
136 |
137 | self.relu = nn.ReLU(inplace=True)
138 |
139 | self.bn1 = nn.BatchNorm2d(in_channels)
140 | self.conv1 = conv3x3(in_channels, ou_channels, stride=stride)
141 |
142 | self.bn2 = nn.BatchNorm2d(ou_channels)
143 | self.conv2 = conv3x3(ou_channels, ou_channels * self.EXPANSION, stride=1)
144 |
145 | if attention_module is not None:
146 | self.conv2 = nn.Sequential(
147 | self.conv2,
148 | attention_module(ou_channels * self.EXPANSION)
149 | )
150 |
151 | self.shortcut = nn.Sequential()
152 | if stride != 1 or in_channels != ou_channels * self.EXPANSION:
153 | self.shortcut = nn.Sequential(
154 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride)
155 | )
156 |
157 | def forward(self, x):
158 | out = self.bn1(x)
159 | out = self.relu(out)
160 | out = self.conv1(out)
161 |
162 | out = self.bn2(out)
163 | out = self.relu(out)
164 | out = self.conv2(out)
165 |
166 | out = out + self.shortcut(x)
167 | return out
168 |
169 |
170 | # Pre-activation Bottleneck in ResNet for CIFAR
171 | class PreBottleNect(nn.Module):
172 |
173 | EXPANSION = 4
174 |
175 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None):
176 | super(PreBottleNect, self).__init__()
177 |
178 | self.relu = nn.ReLU(inplace=True)
179 |
180 | self.bn1 = nn.BatchNorm2d(in_channels)
181 | self.conv1 = conv1x1(in_channels, ou_channels, stride=1)
182 |
183 | self.bn2 = nn.BatchNorm2d(ou_channels)
184 | self.conv2 = conv3x3(ou_channels, ou_channels, stride=stride)
185 |
186 | self.bn3 = nn.BatchNorm2d(ou_channels)
187 | self.conv3 = conv1x1(ou_channels, ou_channels * self.EXPANSION, stride=1)
188 |
189 | if attention_module is not None:
190 | if type(attention_module) == functools.partial:
191 | module_name = attention_module.func.get_module_name()
192 | else:
193 | module_name = attention_module.get_module_name()
194 |
195 | if module_name == "simam":
196 | self.conv2 = nn.Sequential(
197 | self.conv2,
198 | attention_module(ou_channels * self.EXPANSION)
199 | )
200 | else:
201 | self.conv3 = nn.Sequential(
202 | self.conv3,
203 | attention_module(ou_channels * self.EXPANSION)
204 | )
205 |
206 | self.shortcut = nn.Sequential()
207 | if stride != 1 or in_channels != ou_channels * self.EXPANSION:
208 | self.shortcut = nn.Sequential(
209 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride)
210 | )
211 |
212 | def forward(self, x):
213 | out = self.bn1(x)
214 | out = self.relu(out)
215 | out = self.conv1(out)
216 |
217 | out = self.bn2(out)
218 | out = self.relu(out)
219 | out = self.conv2(out)
220 |
221 | out = self.bn3(out)
222 | out = self.relu(out)
223 | out = self.conv3(out)
224 |
225 | out = out + self.shortcut(x)
226 |
227 | return out
228 |
229 |
230 | class WideBasicBlock(nn.Module):
231 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, attention_module=None):
232 | super(WideBasicBlock, self).__init__()
233 | self.bn1 = nn.BatchNorm2d(in_planes)
234 | self.relu1 = nn.ReLU(inplace=True)
235 | self.conv1 = conv3x3(in_planes, out_planes, stride=stride)
236 |
237 | self.bn2 = nn.BatchNorm2d(out_planes)
238 | self.relu2 = nn.ReLU(inplace=True)
239 | self.conv2 = conv3x3(out_planes, out_planes, stride=1)
240 |
241 | if attention_module is not None:
242 | self.conv2 = nn.Sequential(
243 | self.conv2,
244 | attention_module(out_planes)
245 | )
246 | self.droprate = dropRate
247 | self.equalInOut = (in_planes == out_planes)
248 | self.convShortcut = None if self.equalInOut else conv1x1(in_planes, out_planes, stride=stride)
249 | def forward(self, x):
250 | if not self.equalInOut:
251 | x = self.relu1(self.bn1(x))
252 | else:
253 | out = self.relu1(self.bn1(x))
254 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
255 | if self.droprate > 0:
256 | out = F.dropout(out, p=self.droprate, training=self.training)
257 | out = self.conv2(out)
258 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
259 |
260 |
261 | class InvertedResidualBlock(nn.Module):
262 | '''expand + depthwise + pointwise'''
263 | def __init__(self, in_planes, out_planes, expansion, stride, attention_module=None):
264 | super(InvertedResidualBlock, self).__init__()
265 | self.stride = stride
266 |
267 | planes = expansion * in_planes
268 | self.conv1 = conv1x1(in_planes, planes, stride=1)
269 | self.bn1 = nn.BatchNorm2d(planes)
270 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
271 | self.bn2 = nn.BatchNorm2d(planes)
272 | self.conv3 = conv1x1(planes, out_planes, stride=1)
273 | self.bn3 = nn.BatchNorm2d(out_planes)
274 |
275 | self.relu = nn.ReLU(inplace=True)
276 |
277 | if attention_module is not None:
278 | if type(attention_module) == functools.partial:
279 | module_name = attention_module.func.get_module_name()
280 | else:
281 | module_name = attention_module.get_module_name()
282 |
283 | if module_name == "simam":
284 | self.conv2 = nn.Sequential(
285 | self.conv2,
286 | attention_module(planes)
287 | )
288 | else:
289 | self.bn3 = nn.Sequential(
290 | self.bn3,
291 | attention_module(out_planes)
292 | )
293 |
294 | self.shortcut = nn.Sequential()
295 | if stride == 1 and in_planes != out_planes:
296 | self.shortcut = nn.Sequential(
297 | conv1x1(in_planes, out_planes, stride=1),
298 | nn.BatchNorm2d(out_planes),
299 | )
300 |
301 | def forward(self, x):
302 | out = self.conv1(x)
303 | out = self.bn1(out)
304 | out = self.relu(out)
305 |
306 | out = self.conv2(out)
307 | out = self.bn2(out)
308 | out = self.relu(out)
309 |
310 | out = self.conv3(out)
311 | out = self.bn3(out)
312 |
313 | out = out + self.shortcut(x) if self.stride == 1 else out
314 |
315 | return out
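All four block families above share one dispatch rule: SimAM is spliced in directly after a convolution, while channel-wise modules such as SE or CBAM are appended after a BatchNorm. A minimal sketch of driving that dispatch (import path assumed, e_lambda illustrative):

```python
import functools
from networks.attentions.simam_module import simam_module  # assumed import path

simam = functools.partial(simam_module, e_lambda=1e-4)
# The blocks unwrap functools.partial via .func before asking for the name:
assert simam.func.get_module_name() == "simam"

block = BasicBlock(16, 16, stride=1, attention_module=simam)
# block.conv2 is now nn.Sequential(conv3x3, simam); an SE/CBAM module would wrap bn2 instead
```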
--------------------------------------------------------------------------------
/mmdetection/mmdet/models/backbones/resnet_simam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 |
5 | from mmcv.cnn import constant_init, kaiming_init
6 | from mmcv.runner import load_checkpoint
7 | from mmdet.utils import get_root_logger
8 | from torch.nn.modules.batchnorm import _BatchNorm
9 | from ..builder import BACKBONES
10 | from .attentions import simam_module
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=dilation, groups=groups, bias=False, dilation=dilation)
17 |
18 |
19 | def conv1x1(in_planes, out_planes, stride=1):
20 | """1x1 convolution"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self,
28 | inplanes,
29 | planes,
30 | stride=1,
31 | downsample=None,
32 | groups=1,
33 | base_width=64,
34 | dilation=1,
35 | norm_layer=None,
36 | attention_module=None):
37 |
38 | super(BasicBlock, self).__init__()
39 | if norm_layer is None:
40 | norm_layer = nn.BatchNorm2d
41 | if groups != 1 or base_width != 64:
42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
43 | if dilation > 1:
44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
46 | self.conv1 = conv3x3(inplanes, planes, stride)
47 | self.bn1 = norm_layer(planes)
48 | self.relu = nn.ReLU(inplace=True)
49 | self.conv2 = conv3x3(planes, planes)
50 | self.bn2 = norm_layer(planes)
51 | self.downsample = downsample
52 | self.stride = stride
53 |
54 | if attention_module is not None:
55 | if type(attention_module) == functools.partial:
56 | module_name = attention_module.func.get_module_name()
57 | else:
58 | module_name = attention_module.get_module_name()
59 |
60 | if module_name == "simam":
61 | self.conv2 = nn.Sequential(
62 | self.conv2,
63 | attention_module(planes)
64 | )
65 | else:
66 | self.bn2 = nn.Sequential(
67 | self.bn2,
68 | attention_module(planes)
69 | )
70 |
71 | def forward(self, x):
72 | identity = x
73 |
74 | out = self.conv1(x)
75 | out = self.bn1(out)
76 | out = self.relu(out)
77 |
78 | out = self.conv2(out)
79 | out = self.bn2(out)
80 |
81 | if self.downsample is not None:
82 | identity = self.downsample(x)
83 |
84 | out += identity
85 | out = self.relu(out)
86 |
87 | return out
88 |
89 |
90 | class Bottleneck(nn.Module):
91 |
92 | expansion = 4
93 |
94 | def __init__(self,
95 | inplanes,
96 | planes,
97 | stride=1,
98 | downsample=None,
99 | groups=1,
100 | base_width=64,
101 | dilation=1,
102 | norm_layer=None,
103 | attention_module=None):
104 |
105 | super(Bottleneck, self).__init__()
106 | if norm_layer is None:
107 | norm_layer = nn.BatchNorm2d
108 | width = int(planes * (base_width / 64.)) * groups
109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
110 | self.conv1 = conv1x1(inplanes, width)
111 | self.bn1 = norm_layer(width)
112 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
113 | self.bn2 = norm_layer(width)
114 | self.conv3 = conv1x1(width, planes * self.expansion)
115 | self.bn3 = norm_layer(planes * self.expansion)
116 | self.relu = nn.ReLU(inplace=True)
117 | self.downsample = downsample
118 | self.stride = stride
119 |
120 | if attention_module is not None:
121 | if type(attention_module) == functools.partial:
122 | module_name = attention_module.func.get_module_name()
123 | else:
124 | module_name = attention_module.get_module_name()
125 |
126 | if module_name == "simam":
127 | self.conv2 = nn.Sequential(
128 | self.conv2,
129 | attention_module(width)
130 | )
131 | else:
132 | self.bn3 = nn.Sequential(
133 | self.bn3,
134 | attention_module(planes * self.expansion)
135 | )
136 |
137 | def forward(self, x):
138 | identity = x
139 |
140 | out = self.conv1(x)
141 | out = self.bn1(out)
142 | out = self.relu(out)
143 |
144 | out = self.conv2(out)
145 | out = self.bn2(out)
146 | out = self.relu(out)
147 |
148 | out = self.conv3(out)
149 | out = self.bn3(out)
150 |
151 | if self.downsample is not None:
152 | identity = self.downsample(x)
153 |
154 | out += identity
155 | out = self.relu(out)
156 |
157 | return out
158 |
159 |
160 | @BACKBONES.register_module()
161 | class ResNetAM(nn.Module):
162 |
163 | arch_settings = {
164 | 18: (BasicBlock, (2, 2, 2, 2)),
165 | 34: (BasicBlock, (3, 4, 6, 3)),
166 | 50: (Bottleneck, (3, 4, 6, 3)),
167 | 101: (Bottleneck, (3, 4, 23, 3)),
168 | 152: (Bottleneck, (3, 8, 36, 3))
169 | }
170 |
171 | def __init__(self,
172 | depth,
173 | groups=1,
174 | width_per_group=64,
175 | replace_stride_with_dilation=None,
176 | norm_layer=None,
177 | norm_eval=True,
178 | frozen_stages=-1,
179 | attention_type="none",
180 | attention_param=None,
181 | zero_init_residual=False):
182 | super(ResNetAM, self).__init__()
183 | if depth not in self.arch_settings:
184 | raise KeyError(f'invalid depth {depth} for resnet')
185 |
186 | if norm_layer is None:
187 | norm_layer = nn.BatchNorm2d
188 | self._norm_layer = norm_layer
189 |
190 | self.inplanes = 64
191 | self.dilation = 1
192 | self.norm_eval = norm_eval
193 | self.frozen_stages = frozen_stages
194 | self.zero_init_residual = zero_init_residual
195 | block, stage_blocks = self.arch_settings[depth]
196 |
197 | if attention_type == "simam":
198 | attention_module = functools.partial(simam_module, e_lambda=attention_param)
199 | else:
200 | attention_module = None
201 |
202 | if replace_stride_with_dilation is None:
203 | # each element in the tuple indicates if we should replace
204 | # the 2x2 stride with a dilated convolution instead
205 | replace_stride_with_dilation = [False, False, False]
206 | if len(replace_stride_with_dilation) != 3:
207 | raise ValueError("replace_stride_with_dilation should be None "
208 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
209 | self.groups = groups
210 | self.base_width = width_per_group
211 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
212 | bias=False)
213 | self.bn1 = norm_layer(self.inplanes)
214 | self.relu1 = nn.ReLU(inplace=True)
215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
216 |
217 | self.layer1 = self._make_layer(block, 64, stage_blocks[0],
218 | attention_module=attention_module)
219 |
220 | self.layer2 = self._make_layer(block, 128, stage_blocks[1], stride=2,
221 | dilate=replace_stride_with_dilation[0],
222 | attention_module=attention_module)
223 |
224 | self.layer3 = self._make_layer(block, 256, stage_blocks[2], stride=2,
225 | dilate=replace_stride_with_dilation[1],
226 | attention_module=attention_module)
227 |
228 | self.layer4 = self._make_layer(block, 512, stage_blocks[3], stride=2,
229 | dilate=replace_stride_with_dilation[2],
230 | attention_module=attention_module)
231 |
232 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
233 | self.fc = nn.Linear(512 * block.expansion, 1000)
234 |
235 |
236 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention_module=None):
237 | norm_layer = self._norm_layer
238 | downsample = None
239 | previous_dilation = self.dilation
240 | if dilate:
241 | self.dilation *= stride
242 | stride = 1
243 | if stride != 1 or self.inplanes != planes * block.expansion:
244 | downsample = nn.Sequential(
245 | conv1x1(self.inplanes, planes * block.expansion, stride),
246 | norm_layer(planes * block.expansion),
247 | )
248 |
249 | layers = []
250 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
251 | self.base_width, previous_dilation, norm_layer, attention_module))
252 | self.inplanes = planes * block.expansion
253 | for _ in range(1, blocks):
254 | layers.append(block(self.inplanes, planes, groups=self.groups,
255 | base_width=self.base_width, dilation=self.dilation,
256 | norm_layer=norm_layer, attention_module=attention_module))
257 |
258 | return nn.Sequential(*layers)
259 |
260 |
261 | def _freeze_stages(self):
262 | if self.frozen_stages >= 0:
263 | self.bn1.eval()
264 | for m in [self.conv1, self.bn1]:
265 | for param in m.parameters():
266 | param.requires_grad = False
267 |
268 | for i in range(1, self.frozen_stages + 1):
269 | m = getattr(self, f'layer{i}')
270 | m.eval()
271 | for param in m.parameters():
272 | param.requires_grad = False
273 |
274 | def init_weights(self, pretrained=None):
275 | """Initialize the weights in backbone.
276 |
277 | Args:
278 | pretrained (str, optional): Path to pre-trained weights.
279 | Defaults to None.
280 | """
281 | # drop the classification head; detection only consumes the multi-scale features
282 | self.fc = None
283 | self.avgpool = None
284 | if isinstance(pretrained, str):
285 | logger = get_root_logger()
286 | load_checkpoint(self, pretrained, strict=False, logger=logger)
287 | elif pretrained is None:
288 | for m in self.modules():
289 | if isinstance(m, nn.Conv2d):
290 | kaiming_init(m)
291 | elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
292 | constant_init(m, 1)
293 |
294 | if self.zero_init_residual:
295 | for m in self.modules():
296 | if isinstance(m, Bottleneck):
297 | constant_init(m.bn3, 0)  # the norm layers here are named bn3/bn2, not norm3/norm2
298 | elif isinstance(m, BasicBlock):
299 | constant_init(m.bn2, 0)
300 | else:
301 | raise TypeError('pretrained must be a str or None')
302 |
303 | def forward(self, x):
304 | # See note [TorchScript super()]
305 | x = self.conv1(x)
306 | x = self.bn1(x)
307 | x = self.relu1(x)
308 | x = self.maxpool(x)
309 |
310 | outs = []
311 |
312 | x = self.layer1(x)
313 | outs.append(x)
314 |
315 | x = self.layer2(x)
316 | outs.append(x)
317 |
318 | x = self.layer3(x)
319 | outs.append(x)
320 |
321 | x = self.layer4(x)
322 | outs.append(x)
323 |
324 | return tuple(outs)
325 |
326 | def train(self, mode=True):
327 | """Convert the model into training mode while keep normalization layer
328 | freezed."""
329 | super(ResNetAM, self).train(mode)
330 | self._freeze_stages()
331 | if mode and self.norm_eval:
332 | for m in self.modules():
333 | # trick: eval() only has an effect on BatchNorm here
334 | if isinstance(m, _BatchNorm):
335 | m.eval()
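For reference, a backbone fragment of an mmdetection model config wiring in this class might look as follows; the keys mirror the constructor above, and the values (depth, frozen stages, e_lambda) are illustrative rather than copied from the shipped configs:

```python
# illustrative backbone section of an mmdetection model config
backbone = dict(
    type='ResNetAM',        # name registered above via @BACKBONES.register_module()
    depth=50,
    frozen_stages=1,        # freeze the stem and layer1
    norm_eval=True,         # keep BatchNorm in eval mode during training
    attention_type='simam',
    attention_param=0.1,    # forwarded to simam_module as e_lambda
)
```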
--------------------------------------------------------------------------------
/main_cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | # system lib
4 | import os
5 | import time
6 | import sys
7 | import argparse
8 | # numerical libs
9 | import random
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as optim
13 | import torch.backends.cudnn as cudnn
14 | import torch.nn.functional as F
15 | from torchvision import datasets, transforms
16 | # models
17 | from thop import profile
18 | from util import AverageMeter, ProgressMeter, accuracy, parse_gpus
19 | from checkpoint import save_checkpoint, load_checkpoint
20 | from networks.cifar import create_net
21 |
22 |
23 | def adjust_learning_rate(optimizer, epoch, warmup=False):
24 | """Adjust the learning rate"""
25 | if epoch <= 81:
26 | lr = 0.01 if warmup and epoch == 0 else args.base_lr
27 | elif epoch <= 122:
28 | lr = args.base_lr * 0.1
29 | else:
30 | lr = args.base_lr * 0.01
31 |
32 | for param_group in optimizer.param_groups:
33 | param_group["lr"] = lr
34 |
35 |
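Aside from the optional first-epoch warmup, the step schedule above matches a stock PyTorch scheduler; a self-contained sketch of the equivalent formulation (illustrative only, not what this script uses):

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in parameters
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9, weight_decay=5e-4)
# 10x decay entering epoch 82 and again entering epoch 123,
# matching adjust_learning_rate minus its warmup special case
scheduler = MultiStepLR(optimizer, milestones=[82, 123], gamma=0.1)
```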
36 | def train(net, optimizer, epoch, data_loader, args):
37 |
38 | learning_rate = optimizer.param_groups[0]["lr"]
39 |
40 | batch_time = AverageMeter('Time', ':6.3f')
41 | data_time = AverageMeter('Data', ':6.3f')
42 | losses = AverageMeter('Loss', ':.4f')
43 | top1 = AverageMeter('Accuracy', ':4.2f')
44 | progress = ProgressMeter(
45 | len(data_loader),
46 | [batch_time, data_time, losses, top1],
47 | prefix="Epoch (Train LR {:6.4f}): [{}] ".format(learning_rate, epoch))
48 |
49 | net.train()
50 |
51 | tic = time.time()
52 | for batch_idx, (data, target) in enumerate(data_loader):
53 |
54 | data, target = data.to(args.device, non_blocking=True), target.to(args.device, non_blocking=True)
55 |
56 | data_time.update(time.time() - tic)
57 |
58 | optimizer.zero_grad()
59 | output = net(data)
60 | loss = F.cross_entropy(output, target)
61 | loss.backward()
62 | optimizer.step()
63 |
64 | acc = accuracy(output, target)
65 | losses.update(loss.item(), data.size(0))
66 | top1.update(acc[0].item(), data.size(0))
67 |
68 | batch_time.update(time.time() - tic)
69 | tic = time.time()
70 |
71 | if (batch_idx+1) % args.disp_iter == 0 or (batch_idx+1) == len(data_loader):
72 | epoch_msg = progress.get_message(batch_idx+1)
73 | print(epoch_msg)
74 |
75 | args.log_file.write(epoch_msg + "\n")
76 |
77 | def validate(net, epoch, data_loader, args):
78 |
79 | batch_time = AverageMeter('Time', ':6.3f')
80 | data_time = AverageMeter('Data', ':6.3f')
81 | losses = AverageMeter('Loss', ':.4f')
82 | top1 = AverageMeter('Accuracy', ':4.2f')
83 | progress = ProgressMeter(
84 | len(data_loader),
85 | [batch_time, data_time, losses, top1],
86 | prefix="Epoch (Valid LR {:6.4f}): [{}] ".format(0, epoch))
87 |
88 | net.eval()
89 |
90 | with torch.no_grad():
91 | tic = time.time()
92 | for batch_idx, (data, target) in enumerate(data_loader):
93 |
94 | data, target = data.to(args.device, non_blocking=True), target.to(args.device, non_blocking=True)
95 |
96 | data_time.update(time.time() - tic)
97 |
98 | output = net(data)
99 | loss = F.cross_entropy(output, target)
100 |
101 | acc = accuracy(output, target)
102 | losses.update(loss.item(), data.size(0))
103 | top1.update(acc[0].item(), data.size(0))
104 |
105 | batch_time.update(time.time() - tic)
106 | tic = time.time()
107 |
108 | if (batch_idx+1) % args.disp_iter == 0 or (batch_idx+1) == len(data_loader):
109 | epoch_msg = progress.get_message(batch_idx+1)
110 | print(epoch_msg)
111 |
112 | args.log_file.write(epoch_msg + "\n")
113 |
114 | print('-------- Mean Accuracy {top1.avg:.3f} --------'.format(top1=top1))
115 |
116 | return top1.avg
117 |
118 | def main(args):
119 |
120 | if len(args.gpu_ids) > 0:
121 | assert(torch.cuda.is_available())
122 | cudnn.benchmark = True
123 | kwargs = {"num_workers": args.workers, "pin_memory": True}
124 | args.device = torch.device("cuda:{}".format(args.gpu_ids[0]))
125 | else:
126 | kwargs = {}
127 | args.device = torch.device("cpu")
128 |
129 | normalizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
130 |
131 | print("Building dataset: " + args.dataset)
132 |
133 | if args.dataset == "cifar10":
134 | args.num_class = 10
135 | train_loader = torch.utils.data.DataLoader(
136 | datasets.CIFAR10(args.dataset_dir, train=True, download=True,
137 | transform=transforms.Compose([
138 | transforms.Pad(4),
139 | transforms.RandomCrop(32),
140 | transforms.RandomHorizontalFlip(),
141 | transforms.ToTensor(),
142 | normalizer])),
143 | batch_size=args.batch_size, shuffle=True, **kwargs)
144 |
145 | test_loader = torch.utils.data.DataLoader(
146 | datasets.CIFAR10(args.dataset_dir, train=False, transform=transforms.Compose([
147 | transforms.ToTensor(),
148 | normalizer])),
149 | batch_size=100, shuffle=False, **kwargs)
150 |
151 | elif args.dataset == "cifar100":
152 | args.num_class = 100
153 | train_loader = torch.utils.data.DataLoader(
154 | datasets.CIFAR100(args.dataset_dir, train=True, download=True,
155 | transform=transforms.Compose([
156 | transforms.Pad(4),
157 | transforms.RandomCrop(32),
158 | transforms.RandomHorizontalFlip(),
159 | transforms.ToTensor(),
160 | normalizer])),
161 | batch_size=args.batch_size, shuffle=True, **kwargs)
162 |
163 | test_loader = torch.utils.data.DataLoader(
164 | datasets.CIFAR100(args.dataset_dir, train=False, transform=transforms.Compose([
165 | transforms.ToTensor(),
166 | normalizer])),
167 | batch_size=100, shuffle=False, **kwargs)
168 | else: raise ValueError("unknown dataset: " + args.dataset)
169 | net = create_net(args)
170 |
171 | print(net)
172 |
173 | optimizer = optim.SGD(net.parameters(), lr=args.base_lr, momentum=args.beta1, weight_decay=args.weight_decay)
174 |
175 | if args.resume:
176 | net, optimizer, best_acc, start_epoch = load_checkpoint(args, net, optimizer)
177 | else:
178 | start_epoch = 0
179 | best_acc = 0
180 |
181 | x = torch.randn(1, 3, 32, 32)
182 | flops, params = profile(net, inputs=(x,))
183 |
184 | print("Number of params: %.6fM" % (params / 1e6))
185 | print("Number of FLOPs: %.6fG" % (flops / 1e9))
186 |
187 | args.log_file.write("Network - " + args.arch + "\n")
188 | args.log_file.write("Attention Module - " + args.attention_type + "\n")
189 | args.log_file.write("Params - %.6fM" % (params / 1e6) + "\n")
190 | args.log_file.write("FLOPs - %.6fG" % (flops / 1e9) + "\n")
191 | args.log_file.write("--------------------------------------------------" + "\n")
192 |
193 | if len(args.gpu_ids) > 0:
194 | net.to(args.gpu_ids[0])
195 | net = torch.nn.DataParallel(net, args.gpu_ids) # multi-GPUs
196 |
197 | for epoch in range(start_epoch, args.num_epoch):
198 | # if args.wrn:
199 | # adjust_learning_rate_wrn(optimizer, epoch, args.warmup)
200 | # else:
201 | adjust_learning_rate(optimizer, epoch, args.warmup)
202 |
203 | train(net, optimizer, epoch, train_loader, args)
204 | epoch_acc = validate(net, epoch, test_loader, args)
205 |
206 | is_best = epoch_acc > best_acc
207 | best_acc = max(epoch_acc, best_acc)
208 |
209 | save_checkpoint({
210 | "epoch": epoch + 1,
211 | "arch": args.arch,
212 | "state_dict": net.module.cpu().state_dict(),
213 | "best_acc": best_acc,
214 | "optimizer" : optimizer.state_dict(),
215 | }, is_best, epoch, save_path=args.ckpt)
216 |
217 | net.to(args.device)
218 |
219 | args.log_file.write("--------------------------------------------------" + "\n")
220 |
221 | args.log_file.write("best accuracy %4.2f" % best_acc)
222 |
223 | print("Job Done!")
224 |
225 | if __name__ == "__main__":
226 |
227 | parser = argparse.ArgumentParser(description="CIFAR baseline")
228 |
229 | # Model settings
230 | parser.add_argument("--arch", type=str, default="resnet18",
231 | help="network architecture (default: resnet18)")
232 | parser.add_argument("--num_base_filters", type=int, default=16,
233 | help="network base filer numbers (default: 16)")
234 | parser.add_argument("--expansion", type=float, default=1,
235 | help="expansion factor for the mid-layer in resnet-like")
236 | parser.add_argument("--block_type", type=str, default="basic",
237 | help="building block for network, e.g., basic or bottlenect")
238 | parser.add_argument("--attention_type", type=str, default="none",
239 | help="attention type in building block (possible choices none | se | cbam | simam )")
240 | parser.add_argument("--attention_param", type=float, default=4,
241 | help="attention parameter (reduction in CBAM and SE, e_lambda in simam)")
242 |
243 | # Dataset settings
244 | parser.add_argument("--dataset", type=str, default="cifar10",
245 | help="training dataset (default: cifar10)")
246 | parser.add_argument("--dataset_dir", type=str, default="data",
247 | help="data set path (default: data)")
248 | parser.add_argument("--workers", default=16, type=int,
249 | help="number of data loading works")
250 |
251 | # Optimization settings
252 | parser.add_argument("--gpu_ids", default="0",
253 | help="gpus to use, e.g. 0-3 or 0,1,2,3")
254 | parser.add_argument("--batch_size", type=int, default=128,
255 | help="batch size for training and validation (default: 128)")
256 | parser.add_argument("--num_epoch", type=int, default=164,
257 | help="number of epochs to train (default: 164)")
258 | parser.add_argument("--resume", default="", type=str,
259 | help="path to checkpoint for continous training (default: none)")
260 | parser.add_argument("--optim", default="SGD",
261 | help="optimizer")
262 | parser.add_argument("--base_lr", type=float, default=0.1,
263 | help="learning rate (default: 0.1)")
264 | parser.add_argument("--beta1", default=0.9, type=float,
265 | help="momentum for sgd, beta1 for adam")
266 | parser.add_argument("--weight_decay", type=float, default=5e-4,
267 | help="SGD weight decay (default: 5e-4)")
268 | parser.add_argument("--warmup", action="store_true",
269 | help="warmup for deeper network")
270 | parser.add_argument("--wrn", action="store_true",
271 | help="wider resnet for training")
272 |
273 | # Misc
274 | parser.add_argument("--seed", type=int, default=1,
275 | help="random seed (default: 1)")
276 | parser.add_argument("--disp_iter", type=int, default=100,
277 | help="frequence to display training status (default: 100)")
278 | parser.add_argument("--ckpt", default="./ckpts/",
279 | help="folder to output checkpoints")
280 |
281 | args = parser.parse_args()
282 | args.gpu_ids = parse_gpus(args.gpu_ids)
283 |
284 | args.ckpt += args.dataset
285 | args.ckpt += "-" + args.arch
286 | args.ckpt += "-" + args.block_type
287 | if args.attention_type.lower() != "none":
288 | args.ckpt += "-" + args.attention_type
289 | args.ckpt += "-param" + str(args.attention_param)
290 |
291 | args.ckpt += "-nfilters" + str(args.num_base_filters)
292 | args.ckpt += "-expansion" + str(args.expansion)
293 | args.ckpt += "-baselr" + str(args.base_lr)
294 | args.ckpt += "-rseed" + str(args.seed)
295 | for key, val in vars(args).items():
296 | print("{:16} {}".format(key, val))
297 |
298 | if not os.path.isdir(args.ckpt):
299 | os.makedirs(args.ckpt)
300 |
301 | # write to file
302 | args.log_file = open(os.path.join(args.ckpt, "log_file.txt"), mode="w")
303 |
304 | random.seed(args.seed)
305 | torch.manual_seed(args.seed)
306 |
307 | main(args)
308 |
309 | args.log_file.close()
310 |
--------------------------------------------------------------------------------
/networks/imagenet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | # from .utils import load_state_dict_from_url
5 | # note the refinement module
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 | 'resnet152']
9 |
10 |
11 | model_urls = {
12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
21 | }
22 |
23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
24 | """3x3 convolution with padding"""
25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
26 | padding=dilation, groups=groups, bias=False, dilation=dilation)
27 |
28 |
29 | def conv1x1(in_planes, out_planes, stride=1):
30 | """1x1 convolution"""
31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
32 |
33 |
34 | class BasicBlock(nn.Module):
35 | expansion = 1
36 |
37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
38 | base_width=64, dilation=1, norm_layer=None, attention_module=None):
39 | super(BasicBlock, self).__init__()
40 | if norm_layer is None:
41 | norm_layer = nn.BatchNorm2d
42 | if groups != 1 or base_width != 64:
43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
44 | if dilation > 1:
45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
47 | self.conv1 = conv3x3(inplanes, planes, stride)
48 | self.bn1 = norm_layer(planes)
49 | self.relu = nn.ReLU(inplace=True)
50 | self.conv2 = conv3x3(planes, planes)
51 | self.bn2 = norm_layer(planes)
52 | self.downsample = downsample
53 | self.stride = stride
54 |
55 | if attention_module is not None:
56 | if type(attention_module) == functools.partial:
57 | module_name = attention_module.func.get_module_name()
58 | else:
59 | module_name = attention_module.get_module_name()
60 |
61 |
62 | if module_name == "simam":
63 | self.conv2 = nn.Sequential(
64 | self.conv2,
65 | attention_module(planes)
66 | )
67 | else:
68 | self.bn2 = nn.Sequential(
69 | self.bn2,
70 | attention_module(planes)
71 | )
72 |
73 | def forward(self, x):
74 | identity = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 |
83 | if self.downsample is not None:
84 | identity = self.downsample(x)
85 |
86 | out += identity
87 | out = self.relu(out)
88 |
89 | return out
90 |
91 |
92 | class Bottleneck(nn.Module):
93 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
94 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
95 | # according to "Deep residual learning for image recognition" <https://arxiv.org/abs/1512.03385>.
96 | # This variant is also known as ResNet V1.5 and improves accuracy according to
97 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
98 |
99 | expansion = 4
100 |
101 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
102 | base_width=64, dilation=1, norm_layer=None, attention_module=None):
103 | super(Bottleneck, self).__init__()
104 | if norm_layer is None:
105 | norm_layer = nn.BatchNorm2d
106 | width = int(planes * (base_width / 64.)) * groups
107 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
108 | self.conv1 = conv1x1(inplanes, width)
109 | self.bn1 = norm_layer(width)
110 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
111 | self.bn2 = norm_layer(width)
112 | self.conv3 = conv1x1(width, planes * self.expansion)
113 | self.bn3 = norm_layer(planes * self.expansion)
114 | self.relu = nn.ReLU(inplace=True)
115 | self.downsample = downsample
116 | self.stride = stride
117 |
118 | if attention_module is not None:
119 | if type(attention_module) == functools.partial:
120 | module_name = attention_module.func.get_module_name()
121 | else:
122 | module_name = attention_module.get_module_name()
123 |
124 | if module_name == "simam":
125 | self.conv2 = nn.Sequential(
126 | self.conv2,
127 | attention_module(width)
128 | )
129 | else:
130 | self.bn3 = nn.Sequential(
131 | self.bn3,
132 | attention_module(planes * self.expansion)
133 | )
134 |
135 | def forward(self, x):
136 | identity = x
137 |
138 | out = self.conv1(x)
139 | out = self.bn1(out)
140 | out = self.relu(out)
141 |
142 | out = self.conv2(out)
143 | out = self.bn2(out)
144 | out = self.relu(out)
145 |
146 | out = self.conv3(out)
147 | out = self.bn3(out)
148 |
149 | if self.downsample is not None:
150 | identity = self.downsample(x)
151 |
152 | out += identity
153 | out = self.relu(out)
154 |
155 | return out
156 |
157 |
158 | class ResNet(nn.Module):
159 |
160 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
161 | groups=1, width_per_group=64, replace_stride_with_dilation=None,
162 | norm_layer=None, attention_module=None,
163 | deep_stem=False, stem_width=32, avg_down=False):
164 | super(ResNet, self).__init__()
165 | if norm_layer is None:
166 | norm_layer = nn.BatchNorm2d
167 |
168 | self._norm_layer = norm_layer
169 |
170 | self.inplanes = stem_width*2 if deep_stem else 64
171 |
172 | self.dilation = 1
173 |
174 | self.groups = groups
175 | self.base_width = width_per_group
176 |
177 | if replace_stride_with_dilation is None:
178 | # each element in the tuple indicates if we should replace
179 | # the 2x2 stride with a dilated convolution instead
180 | replace_stride_with_dilation = [False, False, False]
181 | if len(replace_stride_with_dilation) != 3:
182 | raise ValueError("replace_stride_with_dilation should be None "
183 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
184 |
185 | if deep_stem:
186 | self.conv1 = nn.Sequential(
187 | conv3x3(3, stem_width, stride=2),
188 | norm_layer(stem_width),
189 | nn.ReLU(),
190 | conv3x3(stem_width, stem_width, stride=1),
191 | norm_layer(stem_width),
192 | nn.ReLU(),
193 | conv3x3(stem_width, self.inplanes, stride=1),
194 | )
195 | else:
196 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
197 |
198 | self.bn1 = norm_layer(self.inplanes if not deep_stem else stem_width*2)
199 | self.relu = nn.ReLU(inplace=True)
200 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
201 |
202 | self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
203 | attention_module=attention_module)
204 |
205 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
206 | dilate=replace_stride_with_dilation[0],
207 | attention_module=attention_module)
208 |
209 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, avg_down=avg_down,
210 | dilate=replace_stride_with_dilation[1],
211 | attention_module=attention_module)
212 |
213 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, avg_down=avg_down,
214 | dilate=replace_stride_with_dilation[2],
215 | attention_module=attention_module)
216 |
217 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
218 | self.fc = nn.Linear(512 * block.expansion, num_classes)
219 |
220 | # print(self.modules)
221 |
222 | for m in self.modules():
223 | # print(m)
224 | if isinstance(m, nn.Conv2d):
225 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
226 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
227 | nn.init.constant_(m.weight, 1)
228 | nn.init.constant_(m.bias, 0)
229 |
230 | # Zero-initialize the last BN in each residual branch,
231 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
232 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
233 | if zero_init_residual:
234 | for m in self.modules():
235 | if isinstance(m, Bottleneck):
236 | nn.init.constant_(m.bn3.weight, 0)
237 | elif isinstance(m, BasicBlock):
238 | nn.init.constant_(m.bn2.weight, 0)
239 |
240 | def _make_layer(self, block, planes, blocks, stride=1, avg_down=False, dilate=False, attention_module=None):
241 | norm_layer = self._norm_layer
242 | downsample = None
243 | previous_dilation = self.dilation
244 | if dilate:
245 | self.dilation *= stride
246 | stride = 1
247 | if stride != 1 or self.inplanes != planes * block.expansion:
248 | if avg_down and stride != 1:
249 | downsample = nn.Sequential(
250 | nn.AvgPool2d(kernel_size=stride, stride=stride, count_include_pad=False, ceil_mode=True),
251 | conv1x1(self.inplanes, planes * block.expansion, 1),
252 | norm_layer(planes * block.expansion)
253 | )
254 | else:
255 | downsample = nn.Sequential(
256 | conv1x1(self.inplanes, planes * block.expansion, stride=stride),
257 | norm_layer(planes * block.expansion)
258 | )
259 |
260 |
261 | layers = []
262 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
263 | self.base_width, previous_dilation, norm_layer, attention_module))
264 | self.inplanes = planes * block.expansion
265 | for _ in range(1, blocks):
266 | layers.append(block(self.inplanes, planes, groups=self.groups,
267 | base_width=self.base_width, dilation=self.dilation,
268 | norm_layer=norm_layer, attention_module=attention_module))
269 |
270 | return nn.Sequential(*layers)
271 |
272 | def _forward_impl(self, x):
273 | # See note [TorchScript super()]
274 | x = self.conv1(x)
275 | x = self.bn1(x)
276 | x = self.relu(x)
277 | x = self.maxpool(x)
278 | x = self.layer1(x)
279 | x = self.layer2(x)
280 | x = self.layer3(x)
281 | x = self.layer4(x)
282 |
283 | x = self.avgpool(x)
284 | x = torch.flatten(x, 1)
285 | x = self.fc(x)
286 |
287 | return x
288 |
289 | def forward(self, x):
290 | return self._forward_impl(x)
291 |
292 |
293 | def _resnet(arch, block, layers, **kwargs):
294 | model = ResNet(block, layers, **kwargs)
295 | # if pretrained:
296 | # state_dict = load_state_dict_from_url(model_urls[arch],
297 | # progress=progress)
298 | # model.load_state_dict(state_dict)
299 | return model
300 |
301 |
302 | def resnet18(**kwargs):
303 | r"""ResNet-18 model from
304 | `"Deep Residual Learning for Image Recognition" `_
305 |
306 | Args:
307 | pretrained (bool): If True, returns a model pre-trained on ImageNet
308 | progress (bool): If True, displays a progress bar of the download to stderr
309 | """
310 |
311 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
312 |
313 |
314 |
315 | def resnet34(**kwargs):
316 | r"""ResNet-34 model from
317 | `"Deep Residual Learning for Image Recognition" `_
318 |
319 | Args:
320 | pretrained (bool): If True, returns a model pre-trained on ImageNet
321 | progress (bool): If True, displays a progress bar of the download to stderr
322 | """
323 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], **kwargs)
324 |
325 |
326 |
327 | def resnet50(**kwargs):
328 | r"""ResNet-50 model from
329 | `"Deep Residual Learning for Image Recognition" `_
330 |
331 | Args:
332 | pretrained (bool): If True, returns a model pre-trained on ImageNet
333 | progress (bool): If True, displays a progress bar of the download to stderr
334 | """
335 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)
336 |
337 |
338 |
339 | def resnet101(**kwargs):
340 | r"""ResNet-101 model from
341 | `"Deep Residual Learning for Image Recognition" `_
342 |
343 | Args:
344 | pretrained (bool): If True, returns a model pre-trained on ImageNet
345 | progress (bool): If True, displays a progress bar of the download to stderr
346 | """
347 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], **kwargs)
348 |
349 |
350 |
351 | def resnet152(**kwargs):
352 | r"""ResNet-152 model from
353 | `"Deep Residual Learning for Image Recognition" `_
354 |
355 | Args:
356 | pretrained (bool): If True, returns a model pre-trained on ImageNet
357 | progress (bool): If True, displays a progress bar of the download to stderr
358 | """
359 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], **kwargs)
360 |
361 |
362 |
363 | def resnext50_32x4d(**kwargs):
364 | r"""ResNeXt-50 32x4d model from
365 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
366 | Args:
367 | pretrained (bool): If True, returns a model pre-trained on ImageNet
368 | progress (bool): If True, displays a progress bar of the download to stderr
369 | """
370 | kwargs['groups'] = 32
371 | kwargs['width_per_group'] = 4
372 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], **kwargs)
373 |
374 |
375 | def resnext101_64x4d(**kwargs):
376 | r"""ResNeXt-50 32x4d model from
377 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
378 | Args:
379 | pretrained (bool): If True, returns a model pre-trained on ImageNet
380 | progress (bool): If True, displays a progress bar of the download to stderr
381 | """
382 | kwargs['groups'] = 64
383 | kwargs['width_per_group'] = 4
384 | return _resnet('resnext101_64x4d', Bottleneck, [3, 4, 23, 3], **kwargs)
385 |
386 |
387 |
388 | def resnet50d(**kwargs):
389 | r"""ResNet-50 model from
390 | `"Deep Residual Learning for Image Recognition" `_
391 |
392 | Args:
393 | pretrained (bool): If True, returns a model pre-trained on ImageNet
394 | progress (bool): If True, displays a progress bar of the download to stderr
395 | """
396 |
397 | return _resnet('resnet50d', Bottleneck, [3, 4, 6, 3],
398 | deep_stem=True, stem_width=32, avg_down=True,
399 | **kwargs)
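A construction sketch showing how these factories accept an attention module, mirroring what networks/imagenet's create_net presumably does (import path assumed, e_lambda illustrative):

```python
import functools
from networks.attentions.simam_module import simam_module  # assumed import path

# ResNet-50 with SimAM inserted after every 3x3 bottleneck convolution
net = resnet50(num_classes=1000,
               attention_module=functools.partial(simam_module, e_lambda=1e-4))
```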
--------------------------------------------------------------------------------
/main_imagenet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import shutil
5 | import time
6 | import warnings
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.parallel
11 | import torch.backends.cudnn as cudnn
12 | import torch.distributed as dist
13 | import torch.optim
14 | import torch.multiprocessing as mp
15 | import torch.utils.data
16 | import torch.utils.data.distributed
17 | import torchvision.transforms as transforms
18 | import torchvision.datasets as datasets
19 | import torchvision.models as models
20 | import numpy as np
21 |
22 | from util import AverageMeter, ProgressMeter, accuracy, parse_gpus
23 | from checkpoint import save_checkpoint, load_checkpoint
24 | from thop import profile
25 | from networks.imagenet import create_net
26 |
27 |
28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
29 | parser.add_argument('data', metavar='DIR',
30 | help='path to dataset')
31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
32 | help='model architecture (default: resnet18)')
33 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
34 | help='number of data loading workers (default: 4)')
35 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
36 | help='number of total epochs to run')
37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
38 | help='manual epoch number (useful on restarts)')
39 | parser.add_argument('-b', '--batch-size', default=256, type=int,
40 | metavar='N',
41 | help='mini-batch size (default: 256), this is the total '
42 | 'batch size of all GPUs on the current node when '
43 | 'using Data Parallel or Distributed Data Parallel')
44 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
45 | metavar='LR', help='initial learning rate', dest='lr')
46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
47 | help='momentum')
48 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
49 | metavar='W', help='weight decay (default: 1e-4)',
50 | dest='weight_decay')
51 | parser.add_argument('-p', '--print-freq', default=10, type=int,
52 | metavar='N', help='print frequency (default: 10)')
53 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
54 | help='path to latest checkpoint (default: none)')
55 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
56 | help='evaluate model on validation set')
57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
58 | help='use pre-trained model')
59 | parser.add_argument('--world-size', default=-1, type=int,
60 | help='number of nodes for distributed training')
61 | parser.add_argument('--rank', default=-1, type=int,
62 | help='node rank for distributed training')
63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
64 | help='url used to set up distributed training')
65 | parser.add_argument('--dist-backend', default='nccl', type=str,
66 | help='distributed backend')
67 | parser.add_argument('--seed', default=None, type=int,
68 | help='seed for initializing training. ')
69 | parser.add_argument('--gpu', default="0",
70 | help='GPU id to use.')
71 | parser.add_argument('--multiprocessing-distributed', action='store_true',
72 | help='Use multi-processing distributed training to launch '
73 | 'N processes per node, which has N GPUs. This is the '
74 | 'fastest way to use PyTorch for either single node or '
75 | 'multi node data parallel training')
76 |
77 | parser.add_argument("--ckpt", default="./ckpts/",
78 | help="folder to output checkpoints")
79 | parser.add_argument("--attention_type", type=str, default="none",
80 | help="attention type (possible choices none | se | cbam | simam)")
81 | parser.add_argument("--attention_param", type=float, default=4,
82 | help="attention parameter (reduction factor in se and cbam, e_lambda in simam)")
83 | parser.add_argument("--log_freq", type=int, default=500,
84 | help="log frequency to file")
85 | parser.add_argument("--cos_lr", action='store_true',
86 | help='use cosine learning rate')
87 | parser.add_argument("--save_weights", default=None, type=str, metavar='PATH',
88 | help='save weights by CPU for mmdetection')
89 |
90 |
91 | best_acc1 = 0
92 |
93 |
94 | def main():
95 | args = parser.parse_args()
96 |
97 | args.ckpt += "imagenet"
98 | args.ckpt += "-" + args.arch
99 | if args.attention_type.lower() != "none":
100 | args.ckpt += "-" + args.attention_type
101 | args.ckpt += "-param" + str(args.attention_param)
102 |
103 |
104 |
105 | args.gpu = parse_gpus(args.gpu)
106 | if args.gpu is not None:
107 | args.device = torch.device("cuda:{}".format(args.gpu[0]))
108 | else:
109 | args.device = torch.device("cpu")
110 |
111 |
112 | if args.seed is not None:
113 | random.seed(args.seed)
114 | torch.manual_seed(args.seed)
115 | cudnn.deterministic = True
116 | warnings.warn('You have chosen to seed training. '
117 | 'This will turn on the CUDNN deterministic setting, '
118 | 'which can slow down your training considerably! '
119 | 'You may see unexpected behavior when restarting '
120 | 'from checkpoints.')
121 | args.ckpt += '-seed' + str(args.seed)
122 |
123 | if not os.path.isdir(args.ckpt):
124 | os.makedirs(args.ckpt)
125 |
126 | if args.gpu is not None:
127 | warnings.warn('You have chosen a specific GPU. This will completely '
128 | 'disable data parallelism.')
129 |
130 | if args.dist_url == "env://" and args.world_size == -1:
131 | args.world_size = int(os.environ["WORLD_SIZE"])
132 |
133 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
134 |
135 | ngpus_per_node = torch.cuda.device_count()
136 | if args.multiprocessing_distributed:
137 | # Since we have ngpus_per_node processes per node, the total world_size
138 | # needs to be adjusted accordingly
139 | args.world_size = ngpus_per_node * args.world_size
140 | # Use torch.multiprocessing.spawn to launch distributed processes: the
141 | # main_worker process function
142 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
143 | else:
144 | # Simply call main_worker function
145 | main_worker(args.gpu, ngpus_per_node, args)
146 |
147 |
148 | def main_worker(gpu, ngpus_per_node, args):
149 | global best_acc1
150 |
151 | if args.gpu is not None:
152 | print("Use GPU: {} for training".format(args.gpu))
153 |
154 | if args.distributed:
155 | if args.dist_url == "env://" and args.rank == -1:
156 | args.rank = int(os.environ["RANK"])
157 | if args.multiprocessing_distributed:
158 | # For multiprocessing distributed training, rank needs to be the
159 | # global rank among all the processes
160 | args.rank = args.rank * ngpus_per_node + gpu
161 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
162 | world_size=args.world_size, rank=args.rank)
163 | # create model
164 | model = create_net(args)
165 |
166 |
167 | x = torch.randn(1, 3, 224, 224)
168 | flops, params = profile(model, inputs=(x,))
169 |
170 | print("model [%s] - params: %.6fM" % (args.arch, params / 1e6))
171 | print("model [%s] - FLOPs: %.6fG" % (args.arch, flops / 1e9))
172 |
173 | log_file = os.path.join(args.ckpt, "log.txt")
174 |
175 | if os.path.exists(log_file):
176 | args.log_file = open(log_file, mode="a")
177 | else:
178 | args.log_file = open(log_file, mode="w")
179 | args.log_file.write("Network - " + args.arch + "\n")
180 | args.log_file.write("Attention Module - " + args.attention_type + "\n")
181 | args.log_file.write("Params - " % str(params) + "\n")
182 | args.log_file.write("FLOPs - " % str(flops) + "\n")
183 | args.log_file.write("--------------------------------------------------" + "\n")
184 |
185 | args.log_file.close()
186 |
187 |
188 | if not torch.cuda.is_available():
189 | print('using CPU, this will be slow')
190 | elif args.distributed:
191 | # For multiprocessing distributed, DistributedDataParallel constructor
192 | # should always set the single device scope, otherwise,
193 | # DistributedDataParallel will use all available devices.
194 | if args.gpu is not None:
195 | torch.cuda.set_device(args.device)
196 |             model.cuda(args.device)
197 | # When using a single GPU per process and per
198 | # DistributedDataParallel, we need to divide the batch size
199 | # ourselves based on the total number of GPUs we have
200 | args.batch_size = int(args.batch_size / ngpus_per_node)
201 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
202 |             model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu[0]])
203 | else:
204 | model.cuda()
205 | # DistributedDataParallel will divide and allocate batch_size to all
206 | # available GPUs if device_ids are not set
207 | model = torch.nn.parallel.DistributedDataParallel(model)
208 | elif args.gpu is not None:
209 | torch.cuda.set_device(args.device)
210 |         model = model.to(args.device)
211 | model = torch.nn.DataParallel(model, args.gpu)
212 |
213 | print(model)
214 |
215 | # define loss function (criterion) and optimizer
216 |     criterion = nn.CrossEntropyLoss().to(args.device)
217 |
218 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
219 | momentum=args.momentum,
220 | weight_decay=args.weight_decay)
221 | if args.resume:
222 | model, optimizer, best_acc1, start_epoch = load_checkpoint(args, model, optimizer)
223 | args.start_epoch = start_epoch
224 |
225 | cudnn.benchmark = True
226 |
227 | # Data loading code
228 | traindir = os.path.join(args.data, 'train')
229 | valdir = os.path.join(args.data, 'val')
230 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
231 | std=[0.229, 0.224, 0.225])
232 |
233 | train_dataset = datasets.ImageFolder(
234 | traindir,
235 | transforms.Compose([
236 | transforms.RandomResizedCrop(224),
237 | transforms.RandomHorizontalFlip(),
238 | transforms.ToTensor(),
239 | normalize,
240 | ]))
241 |
242 | if args.distributed:
243 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
244 | else:
245 | train_sampler = None
246 |
247 | train_loader = torch.utils.data.DataLoader(
248 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
249 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
250 |
251 | val_loader = torch.utils.data.DataLoader(
252 | datasets.ImageFolder(valdir, transforms.Compose([
253 | transforms.Resize(256),
254 | transforms.CenterCrop(224),
255 | transforms.ToTensor(),
256 | normalize,
257 | ])),
258 | batch_size=args.batch_size, shuffle=False,
259 | num_workers=args.workers, pin_memory=True)
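    # Both loaders use the standard ImageNet channel statistics defined above;
    # validation follows the usual resize-to-256 / center-crop-224 protocol so
    # that training and evaluation see the same 224x224 input resolution.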
260 |
261 | if args.save_weights is not None: # "deparallelize" saved weights
262 | print("=> saving 'deparallelized' weights [%s]" % args.save_weights)
263 |         if hasattr(model, "module"): model = model.module  # unwrap (Distributed)DataParallel if wrapped
264 | model = model.cpu()
265 | torch.save({'state_dict': model.state_dict()}, args.save_weights, _use_new_zipfile_serialization=False)
266 | return
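    # (Distributed)DataParallel prefixes every state_dict key with "module.",
    # so unwrapping to model.module before saving yields weights that load
    # directly into the bare network. For an already-saved parallel checkpoint,
    # an equivalent fix-up would be (a sketch):
    #
    #     state = {k.replace("module.", "", 1): v
    #              for k, v in ckpt["state_dict"].items()}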
267 |
268 |
269 | if args.evaluate:
270 | args.log_file = open(log_file, mode="a")
271 | validate(val_loader, model, criterion, args)
272 | args.log_file.close()
273 | return
274 |
275 | if args.cos_lr:
276 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
277 | for epoch in range(args.start_epoch):
278 | scheduler.step()
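    # CosineAnnealingLR follows
    #     lr_t = eta_min + 0.5 * (lr_0 - eta_min) * (1 + cos(pi * t / T_max))
    # with T_max = args.epochs and eta_min = 0 by default; the loop above
    # replays start_epoch steps so a resumed run rejoins the curve at the
    # right point.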
279 |
280 |
281 | for epoch in range(args.start_epoch, args.epochs):
282 |
283 | args.log_file = open(log_file, mode="a")
284 |
285 | if args.distributed:
286 | train_sampler.set_epoch(epoch)
287 |
288 |         if not args.cos_lr:
289 |             adjust_learning_rate(optimizer, epoch, args)
290 |         else:
291 |             scheduler.step()
292 |             print('[%03d] %.5f' % (epoch, scheduler.get_last_lr()[0]))
293 |
294 |
295 | # train for one epoch
296 | train(train_loader, model, criterion, optimizer, epoch, args)
297 |
298 | # evaluate on validation set
299 | acc1 = validate(val_loader, model, criterion, args)
300 |
301 | # remember best acc@1 and save checkpoint
302 | is_best = acc1 > best_acc1
303 | best_acc1 = max(acc1, best_acc1)
304 |
305 | args.log_file.close()
306 |
307 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
308 | and args.rank % ngpus_per_node == 0):
309 |
310 | save_checkpoint({
311 | "epoch": epoch + 1,
312 | "arch": args.arch,
313 | "state_dict": model.state_dict(),
314 | "best_acc": best_acc1,
315 | "optimizer" : optimizer.state_dict(),
316 | }, is_best, epoch, save_path=args.ckpt)
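        # The guard above lets exactly one process per node (local rank 0)
        # write the checkpoint, which avoids several workers racing to save
        # the same file in the multiprocessing-distributed case.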
317 |
318 |
319 | def train(train_loader, model, criterion, optimizer, epoch, args):
320 | batch_time = AverageMeter('Time', ':6.3f')
321 | data_time = AverageMeter('Data', ':6.3f')
322 | losses = AverageMeter('Loss', ':.4e')
323 | top1 = AverageMeter('Acc@1', ':6.2f')
324 | top5 = AverageMeter('Acc@5', ':6.2f')
325 | progress = ProgressMeter(
326 | len(train_loader),
327 | [batch_time, data_time, losses, top1, top5],
328 | prefix="Epoch: [{}]".format(epoch))
329 |
330 |     param_group = optimizer.param_groups[0]
331 |     curr_lr = param_group["lr"]
332 |
333 | # switch to train mode
334 | model.train()
335 |
336 | end = time.time()
337 | for i, (images, target) in enumerate(train_loader):
338 | # measure data loading time
339 | data_time.update(time.time() - end)
340 |
341 | if args.gpu is not None:
342 | images = images.to(args.device, non_blocking=True)
343 | if torch.cuda.is_available():
344 | target = target.to(args.device, non_blocking=True)
345 |
346 | # compute output
347 | output = model(images)
348 | loss = criterion(output, target)
349 |
350 | # measure accuracy and record loss
351 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
352 | losses.update(loss.item(), images.size(0))
353 | top1.update(acc1[0], images.size(0))
354 | top5.update(acc5[0], images.size(0))
355 |
356 | # compute gradient and do SGD step
357 | optimizer.zero_grad()
358 | loss.backward()
359 | optimizer.step()
360 |
361 | # measure elapsed time
362 | batch_time.update(time.time() - end)
363 | end = time.time()
364 |
365 |         if i % args.print_freq == 0 or i % args.log_freq == 0:
366 |             epoch_msg = progress.get_message(i)
367 |             epoch_msg += ("\tLr {:.4f}".format(curr_lr))
368 |             if i % args.print_freq == 0:
369 |                 print(epoch_msg)
370 |             if i % args.log_freq == 0:
371 |                 args.log_file.write(epoch_msg + "\n")
372 |
373 |
374 | def validate(val_loader, model, criterion, args):
375 | batch_time = AverageMeter('Time', ':6.3f')
376 | losses = AverageMeter('Loss', ':.4e')
377 | top1 = AverageMeter('Acc@1', ':6.2f')
378 | top5 = AverageMeter('Acc@5', ':6.2f')
379 |
380 | progress = ProgressMeter(
381 | len(val_loader),
382 | [batch_time, losses, top1, top5],
383 | prefix='Test: ')
384 |
385 | # switch to evaluate mode
386 | model.eval()
387 |
388 | with torch.no_grad():
389 | end = time.time()
390 | for i, (images, target) in enumerate(val_loader):
391 |
392 | if args.gpu is not None:
393 | images = images.to(args.device, non_blocking=True)
394 | if torch.cuda.is_available():
395 | target = target.to(args.device, non_blocking=True)
396 |
397 | # compute outputs
398 | output = model(images)
399 | loss = criterion(output, target)
400 |
401 | # measure accuracy and record loss
402 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
403 | losses.update(loss.item(), images.size(0))
404 | top1.update(acc1[0], images.size(0))
405 | top5.update(acc5[0], images.size(0))
406 |
407 | # measure elapsed time
408 | batch_time.update(time.time() - end)
409 | end = time.time()
410 |
411 | if i % args.print_freq == 0:
412 | epoch_msg = progress.get_message(i)
413 | print(epoch_msg)
414 |
415 | # TODO: this should also be done with the ProgressMeter
416 | # print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
417 | # .format(top1=top1, top5=top5))
418 |
419 | epoch_msg = '----------- Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} -----------'.format(top1=top1, top5=top5)
420 |
421 | print(epoch_msg)
422 |
423 | args.log_file.write(epoch_msg + "\n")
424 |
425 |
426 | return top1.avg
427 |
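# `accuracy` is imported from util.py; it is assumed to match the standard
# PyTorch ImageNet-example helper, roughly:
#
#     def accuracy(output, target, topk=(1,)):
#         with torch.no_grad():
#             maxk = max(topk)
#             _, pred = output.topk(maxk, 1, True, True)
#             pred = pred.t()
#             correct = pred.eq(target.view(1, -1).expand_as(pred))
#             res = []
#             for k in topk:
#                 correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
#                 res.append(correct_k.mul_(100.0 / output.size(0)))
#             return res
#
# (a sketch for reference only; the repo's own util.py is authoritative).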
428 | def adjust_learning_rate(optimizer, epoch, args):
429 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
430 | lr = args.lr * (0.1 ** (epoch // 30))
431 | for param_group in optimizer.param_groups:
432 | param_group['lr'] = lr
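    # Worked example with args.lr = 0.1: epochs 0-29 train at 0.1, epochs
    # 30-59 at 0.01, and epochs 60-89 at 0.001, since 0.1 ** (epoch // 30)
    # divides the base rate by 10 every 30 epochs.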
433 |
434 |
435 | if __name__ == '__main__':
436 | main()
--------------------------------------------------------------------------------