├── .gitignore
├── figures
│   ├── attentions.png
│   └── training_curves.png
├── networks
│   ├── __pycache__
│   │   └── __init__.cpython-38.pyc
│   ├── cifar
│   │   ├── __pycache__
│   │   │   ├── block.cpython-38.pyc
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── preresnet.cpython-38.pyc
│   │   │   ├── mobilenetv2.cpython-38.pyc
│   │   │   └── wideresnet.cpython-38.pyc
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   ├── wideresnet.py
│   │   ├── resnet.py
│   │   ├── preresnet.py
│   │   └── block.py
│   ├── imagenet
│   │   ├── __pycache__
│   │   │   ├── resnet.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── mobilenetv2.cpython-38.pyc
│   │   ├── __init__.py
│   │   ├── mobilenetv2.py
│   │   └── resnet.py
│   ├── attentions
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── se_module.cpython-38.pyc
│   │   │   └── simam_module.cpython-38.pyc
│   │   ├── __init__.py
│   │   ├── se_module.py
│   │   ├── simam_module.py
│   │   └── cbam_module.py
│   └── __init__.py
├── mmdetection
│   ├── configs
│   │   ├── mask_rcnn
│   │   │   ├── mask_rcnn_r101simam_fpn_1x_coco.py
│   │   │   └── mask_rcnn_r50simam_fpn_1x_coco.py
│   │   ├── faster_rcnn
│   │   │   ├── faster_rcnn_r101simam_fpn_1x_coco.py
│   │   │   └── faster_rcnn_r50simam_fpn_1x_coco.py
│   │   └── _base_
│   │       ├── schedules
│   │       │   └── schedule_1x_lr0.01.py
│   │       └── models
│   │           ├── faster_rcnn_r50simam_fpn.py
│   │           └── mask_rcnn_r50simam_fpn.py
│   └── mmdet
│       └── models
│           └── backbones
│               ├── attentions
│               │   └── simam_module.py
│               └── resnet_simam.py
├── checkpoint.py
├── util.py
├── README.md
├── main_cifar.py
└── main_imagenet.py
/.gitignore:
--------------------------------------------------------------------------------
pretrained
SimAM-previous.rar
--------------------------------------------------------------------------------
/figures/attentions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/figures/attentions.png
--------------------------------------------------------------------------------
/figures/training_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/figures/training_curves.png
--------------------------------------------------------------------------------
/networks/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/cifar/__pycache__/block.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/block.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/cifar/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/cifar/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/networks/cifar/__pycache__/preresnet.cpython-38.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/preresnet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/imagenet/__pycache__/resnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/imagenet/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/cifar/__pycache__/mobilenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/mobilenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /networks/cifar/__pycache__/wideresnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/cifar/__pycache__/wideresnet.cpython-38.pyc -------------------------------------------------------------------------------- /networks/imagenet/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/imagenet/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/attentions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/attentions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/attentions/__pycache__/se_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/attentions/__pycache__/se_module.cpython-38.pyc -------------------------------------------------------------------------------- /networks/imagenet/__pycache__/mobilenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/imagenet/__pycache__/mobilenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /networks/attentions/__pycache__/simam_module.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZjjConan/SimAM/HEAD/networks/attentions/__pycache__/simam_module.cpython-38.pyc -------------------------------------------------------------------------------- /networks/attentions/__init__.py: -------------------------------------------------------------------------------- 1 | from ..import find_module_using_name 2 | 3 | 4 | 5 | def get_attention_module(attention_type="none"): 6 | 7 | return find_module_using_name(attention_type.lower()) -------------------------------------------------------------------------------- /mmdetection/configs/mask_rcnn/mask_rcnn_r101simam_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = './mask_rcnn_r50simam_fpn_1x_coco.py' 2 | model = dict(pretrained='checkpoints/simam-net/resnet101.pth.tar', 
backbone=dict(depth=101)) -------------------------------------------------------------------------------- /mmdetection/configs/faster_rcnn/faster_rcnn_r101simam_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = './faster_rcnn_r50simam_fpn_1x_coco.py' 2 | model = dict(pretrained='checkpoints/simam-net/resnet101.pth.tar', backbone=dict(depth=101)) 3 | -------------------------------------------------------------------------------- /mmdetection/configs/mask_rcnn/mask_rcnn_r50simam_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/mask_rcnn_r50simam_fpn.py', 3 | '../_base_/datasets/coco_instance.py', 4 | '../_base_/schedules/schedule_1x_lr0.01.py', '../_base_/default_runtime.py' 5 | ] 6 | 7 | find_unused_parameters=True -------------------------------------------------------------------------------- /mmdetection/configs/faster_rcnn/faster_rcnn_r50simam_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/faster_rcnn_r50simam_fpn.py', 3 | '../_base_/datasets/coco_detection.py', 4 | '../_base_/schedules/schedule_1x_lr0.01.py', '../_base_/default_runtime.py' 5 | ] 6 | 7 | find_unused_parameters=True 8 | -------------------------------------------------------------------------------- /mmdetection/configs/_base_/schedules/schedule_1x_lr0.01.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[8, 11]) 11 | total_epochs = 12 12 | -------------------------------------------------------------------------------- /networks/attentions/se_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class se_module(nn.Module): 4 | def __init__(self, channel, reduction=16): 5 | super(se_module, self).__init__() 6 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 7 | self.fc = nn.Sequential( 8 | nn.Linear(channel, int(channel // reduction), bias=False), 9 | nn.ReLU(inplace=True), 10 | nn.Linear(int(channel // reduction), channel, bias=False), 11 | nn.Sigmoid() 12 | ) 13 | 14 | @staticmethod 15 | def get_module_name(): 16 | return "se" 17 | 18 | def forward(self, x): 19 | b, c, _, _ = x.size() 20 | y = self.avg_pool(x).view(b, c) 21 | y = self.fc(y).view(b, c, 1, 1) 22 | return x * y -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.nn as nn 3 | 4 | def find_module_using_name(module_name): 5 | 6 | if module_name == "none": 7 | return None 8 | 9 | module_filename = "networks.attentions." + module_name + "_module" 10 | modellib = importlib.import_module(module_filename) 11 | 12 | module = None 13 | target_model_name = module_name + '_module' 14 | 15 | for name, cls in modellib.__dict__.items(): 16 | if name.lower() == target_model_name.lower() and issubclass(cls, nn.Module): 17 | module = cls 18 | 19 | if module is None: 20 | print("In %s.py, there should be a subclass of nn.Module with class name that matches %s in lowercase." 
              % (module_filename, target_model_name))
        exit(0)

    return module
--------------------------------------------------------------------------------
/networks/attentions/simam_module.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class simam_module(torch.nn.Module):
    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()

        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):

        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activation(y)
--------------------------------------------------------------------------------
/mmdetection/mmdet/models/backbones/attentions/simam_module.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class simam_module(torch.nn.Module):
    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()

        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):

        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activation(y)
--------------------------------------------------------------------------------
/networks/imagenet/__init__.py:
--------------------------------------------------------------------------------
import functools
from ..attentions import get_attention_module
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d
from .mobilenetv2 import mobilenet_v2
from .resnet import resnet50d

model_dict = {
    "mobilenet_v2": mobilenet_v2,
    "resnet18": resnet18,
    "resnet34": resnet34,
    "resnet50": resnet50,
    "resnet101": resnet101,
    "resnet152": resnet152,
    "resnet50d": resnet50d,
    "resnext50_32x4d": resnext50_32x4d
}


def create_net(args):
    net = None

    attention_module = get_attention_module(args.attention_type)

    # se and cbam take a reduction ratio; simam takes the coefficient lambda
    if args.attention_type == "se" or args.attention_type == "cbam":
        attention_module = functools.partial(attention_module, reduction=args.attention_param)
    elif args.attention_type == "simam":
        attention_module = functools.partial(attention_module, e_lambda=args.attention_param)

    kwargs = {}
    kwargs["num_classes"] = 1000
    kwargs["attention_module"] = attention_module

    net = model_dict[args.arch.lower()](**kwargs)

    return net
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
import os
import torch
import shutil

def save_checkpoint(state, is_best, epoch, save_path='./'):
    print("=> saving checkpoint '{}'".format(epoch))
    torch.save(state, os.path.join(save_path, 'checkpoint.pth.tar'))
    if epoch % 10 == 0:
        torch.save(state, os.path.join(save_path, 'checkpoint_%03d.pth.tar' % epoch))
    if is_best:
        if epoch >= 90:
            shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
                            os.path.join(save_path, 'model_best_in_100_epochs.pth.tar'))
        else:
            shutil.copyfile(os.path.join(save_path, 'checkpoint.pth.tar'),
                            os.path.join(save_path, 'model_best_in_090_epochs.pth.tar'))


def load_checkpoint(args, model, optimizer=None, verbose=True):

    checkpoint = torch.load(args.resume)

    start_epoch = 0
    best_acc = 0

    if "epoch" in checkpoint:
        start_epoch = checkpoint['epoch']

    if "best_acc" in checkpoint:
        best_acc = checkpoint['best_acc']

    model.load_state_dict(checkpoint['state_dict'], strict=False)

    if optimizer is not None and "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])

        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(args.device)

    if verbose:
        print("=> loading checkpoint '{}' (epoch {})"
              .format(args.resume, start_epoch))

    return model, optimizer, best_acc, start_epoch
--------------------------------------------------------------------------------
/networks/cifar/__init__.py:
--------------------------------------------------------------------------------
import functools
from ..attentions import get_attention_module
from .mobilenetv2 import MobileNetV2Wrapper
from .resnet import ResNet20, ResNet32, ResNet56, ResNet110, ResNet164
from .preresnet import PreResNet20, PreResNet32, PreResNet56, PreResNet110, PreResNet164
from .wideresnet import WideResNet28x10, WideResNet40x10
from .block import BasicBlock, BottleNect, PreBasicBlock, PreBottleNect, InvertedResidualBlock, WideBasicBlock


model_dict = {
    "resnet20": ResNet20,
    "resnet32": ResNet32,
    "resnet56": ResNet56,
    "resnet110": ResNet110,
    "resnet164": ResNet164,
    "preresnet20": PreResNet20,
    "preresnet32": PreResNet32,
    "preresnet56": PreResNet56,
    "preresnet110": PreResNet110,
    "preresnet164": PreResNet164,
    "wideresnet28x10": WideResNet28x10,
    "wideresnet40x10": WideResNet40x10,
    "mobilenetv2": MobileNetV2Wrapper,
}

def get_block(block_type="basic"):

    block_type = block_type.lower()

    if block_type == "basic":
        b = BasicBlock
    elif block_type == "bottlenect":
        b = BottleNect
    elif block_type == "prebasic":
        b = PreBasicBlock
    elif block_type == "prebottlenect":
        b = PreBottleNect
    elif block_type == "ivrd":
        b = InvertedResidualBlock
    elif block_type == "widebasic":
        b = WideBasicBlock
    else:
        raise NotImplementedError('block [%s] is not found' % block_type)
    return b


def create_net(args):
    net = None

    block_module = get_block(args.block_type)
    attention_module = get_attention_module(args.attention_type)

    if args.attention_type == "se" or args.attention_type == "cbam":
        attention_module =
functools.partial(attention_module, reduction=args.attention_param) 55 | elif args.attention_type == "simam": 56 | attention_module = functools.partial(attention_module, e_lambda=args.attention_param) 57 | 58 | net = model_dict[args.arch.lower()]( 59 | num_class = args.num_class, 60 | block = block_module, 61 | attention_module = attention_module 62 | ) 63 | 64 | return net -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import fnmatch 4 | import torch 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value""" 9 | def __init__(self, name, fmt=':f'): 10 | self.name = name 11 | self.fmt = fmt 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | 26 | def __str__(self): 27 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 28 | return fmtstr.format(**self.__dict__) 29 | 30 | 31 | class ProgressMeter(object): 32 | def __init__(self, num_batches, meters, prefix=""): 33 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 34 | self.meters = meters 35 | self.prefix = prefix 36 | 37 | def get_message(self, batch): 38 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 39 | entries += [str(meter) for meter in self.meters] 40 | return ('\t').join(entries) 41 | 42 | def _get_batch_fmtstr(self, num_batches): 43 | num_digits = len(str(num_batches // 1)) 44 | fmt = "{:" + str(num_digits) + "d}" 45 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 46 | 47 | 48 | def accuracy(output, target, topk=(1,)): 49 | """Computes the accuracy over the k top predictions for the specified values of k""" 50 | with torch.no_grad(): 51 | maxk = max(topk) 52 | batch_size = target.size(0) 53 | 54 | _, pred = output.topk(maxk, 1, True, True) 55 | pred = pred.t() 56 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 57 | 58 | res = [] 59 | for k in topk: 60 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 61 | res.append(correct_k.mul_(100.0 / batch_size)) 62 | return res 63 | 64 | 65 | def parse_gpus(gpu_ids): 66 | gpus = gpu_ids.split(',') 67 | gpu_ids = [] 68 | for g in gpus: 69 | g_int = int(g) 70 | if g_int >= 0: 71 | gpu_ids.append(g_int) 72 | if not gpu_ids: 73 | return None 74 | return gpu_ids 75 | -------------------------------------------------------------------------------- /networks/cifar/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | '''MobileNetV2 in PyTorch. 2 | 3 | See the paper "Inverted Residuals and Linear Bottlenecks: 4 | Mobile Networks for Classification, Detection and Segmentation" for more details. 
5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from .block import InvertedResidualBlock 10 | 11 | class MobileNetV2(nn.Module): 12 | # (expansion, out_planes, num_blocks, stride) 13 | cfg = [(1, 16, 1, 1), 14 | (6, 24, 2, 1), # NOTE: change stride 2 -> 1 for CIFAR10 15 | (6, 32, 3, 2), 16 | (6, 64, 4, 2), 17 | (6, 96, 3, 1), 18 | (6, 160, 3, 2), 19 | (6, 320, 1, 1)] 20 | 21 | def __init__(self, block, num_blocks=0, num_class=10): 22 | super(MobileNetV2, self).__init__() 23 | # NOTE: change conv1 stride 2 -> 1 for CIFAR10 24 | self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False) 25 | self.bn1 = nn.BatchNorm2d(32) 26 | self.layers = self._make_layers(block, in_planes=32) 27 | self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False) 28 | self.bn2 = nn.BatchNorm2d(1280) 29 | self.linear = nn.Linear(1280, num_class) 30 | 31 | def _make_layers(self, block, in_planes): 32 | layers = [] 33 | for expansion, out_planes, num_blocks, stride in self.cfg: 34 | strides = [stride] + [1]*(num_blocks-1) 35 | for stride in strides: 36 | layers.append(block(in_planes, out_planes, expansion, stride)) 37 | in_planes = out_planes 38 | return nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | out = F.relu(self.bn1(self.conv1(x))) 42 | out = self.layers(out) 43 | out = F.relu(self.bn2(self.conv2(out))) 44 | # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 45 | out = F.avg_pool2d(out, 4) 46 | out = out.view(out.size(0), -1) 47 | out = self.linear(out) 48 | return out 49 | 50 | 51 | def MobileNetV2Wrapper(num_class=10, block=None, attention_module=None): 52 | 53 | b = lambda in_planes, out_planes, expansion, stride: \ 54 | InvertedResidualBlock(in_planes, out_planes, expansion, stride, attention_module=attention_module) 55 | 56 | return MobileNetV2(b, num_blocks=0, num_class=num_class) 57 | 58 | -------------------------------------------------------------------------------- /networks/cifar/wideresnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import functools 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .block import WideBasicBlock 6 | 7 | 8 | class NetworkBlock(nn.Module): 9 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 10 | super(NetworkBlock, self).__init__() 11 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 12 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 13 | layers = [] 14 | for i in range(nb_layers): 15 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 16 | return nn.Sequential(*layers) 17 | def forward(self, x): 18 | return self.layer(x) 19 | 20 | class WideResNet(nn.Module): 21 | def __init__(self, block, depth, widen_factor=1, dropRate=0.0, num_class=10): 22 | super(WideResNet, self).__init__() 23 | nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor] 24 | assert (depth - 4) % 6 == 0, 'depth should be 6n+4' 25 | n = (depth - 4) // 6 26 | # block = WideBasicBlock 27 | # 1st conv before any network block 28 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False) 29 | # 1st block 30 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 31 | # 2nd block 32 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 33 | # 3rd block 34 | self.block3 = 
NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 35 | # global average pooling and classifier 36 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.fc = nn.Linear(nChannels[3], num_class) 39 | self.fc.bias.data.zero_() 40 | # print(self.fc.bias.data) 41 | self.nChannels = nChannels[3] 42 | 43 | for m in self.modules(): 44 | if isinstance(m, nn.Conv2d): 45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 46 | m.weight.data.normal_(0, math.sqrt(2. / n)) 47 | elif isinstance(m, nn.BatchNorm2d): 48 | m.weight.data.fill_(1) 49 | m.bias.data.zero_() 50 | # elif isinstance(m, nn.Linear): 51 | # if hasattr(m, "bias"): 52 | # m.bias.data.zero_() 53 | 54 | def forward(self, x): 55 | out = self.conv1(x) 56 | out = self.block1(out) 57 | out = self.block2(out) 58 | out = self.block3(out) 59 | out = self.relu(self.bn1(out)) 60 | out = F.avg_pool2d(out, 8) 61 | out = out.view(-1, self.nChannels) 62 | return self.fc(out) 63 | 64 | 65 | def WideResNetWrapper(depth, widen_factor, dropRate=0, num_class=10, attention_module=None): 66 | 67 | b = lambda in_planes, planes, stride, dropRate: \ 68 | WideBasicBlock(in_planes, planes, stride, dropRate, attention_module=attention_module) 69 | 70 | return WideResNet(b, depth, widen_factor, dropRate, num_class=num_class) 71 | 72 | 73 | def WideResNet28x10(num_class=10, block=None, attention_module=None): 74 | 75 | return WideResNetWrapper( 76 | depth = 28, 77 | widen_factor = 10, 78 | dropRate = 0.3, 79 | num_class = num_class, 80 | attention_module = attention_module) 81 | 82 | 83 | def WideResNet40x10(num_class=10, block=None, attention_module=None): 84 | 85 | return WideResNetWrapper( 86 | depth = 40, 87 | widen_factor = 10, 88 | dropRate = 0.3, 89 | num_class = num_class, 90 | attention_module = attention_module) -------------------------------------------------------------------------------- /networks/cifar/resnet.py: -------------------------------------------------------------------------------- 1 | """PyTorch implementation of ResNet 2 | 3 | ResNet modifications written by Bichen Wu and Alvin Wan, based 4 | off of ResNet implementation by Kuang Liu. 5 | 6 | Reference: 7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 8 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 9 | """ 10 | import functools 11 | import torch.nn as nn 12 | from .block import BasicBlock, BottleNect 13 | 14 | class ResNet(nn.Module): 15 | def __init__(self, block, num_blocks, num_class=10): 16 | super(ResNet, self).__init__() 17 | 18 | self.num_class = num_class 19 | self.in_channels = num_filters = 16 20 | 21 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(self.in_channels) 23 | 24 | self.relu = nn.ReLU(inplace=True) 25 | self.avgpool = nn.AdaptiveAvgPool2d(1) 26 | 27 | self.layer1 = self._make_layer(block, self.in_channels, num_blocks[0], stride=1) 28 | self.layer2 = self._make_layer(block, int(num_filters*2), num_blocks[1], stride=2) 29 | self.layer3 = self._make_layer(block, int(num_filters*4), num_blocks[2], stride=2) 30 | self.linear = nn.Linear(int(num_filters*4*block(16,16,1).EXPANSION), num_class) 31 | 32 | def _make_layer(self, block, ou_channels, num_blocks, stride): 33 | strides = [stride] + [1]*(num_blocks-1) 34 | layers = [] 35 | for stride in strides: 36 | layers.append(block(self.in_channels, ou_channels, stride)) 37 | self.in_channels = int(ou_channels * block(16,16,1).EXPANSION) 38 | return nn.Sequential(*layers) 39 | 40 | def forward(self, x): 41 | out = self.relu(self.bn1(self.conv1(x))) 42 | out = self.layer1(out) 43 | out = self.layer2(out) 44 | out = self.layer3(out) 45 | out = self.avgpool(out) 46 | out = out.view(out.size(0), -1) 47 | out = self.linear(out) 48 | return out 49 | 50 | 51 | def ResNetWrapper(num_blocks, num_class=10, block=None, attention_module=None): 52 | 53 | b = lambda in_planes, planes, stride: \ 54 | block(in_planes, planes, stride, attention_module=attention_module) 55 | 56 | return ResNet(b, num_blocks, num_class=num_class) 57 | 58 | 59 | def ResNet20(num_class=10, block=None, attention_module=None): 60 | return ResNetWrapper( 61 | num_blocks = [3, 3, 3], 62 | num_class = num_class, 63 | block = block, 64 | attention_module = attention_module) 65 | 66 | def ResNet32(num_class=10, block=None, attention_module=None): 67 | return ResNetWrapper( 68 | num_blocks = [5, 5, 5], 69 | num_class = num_class, 70 | block = block, 71 | attention_module = attention_module) 72 | 73 | 74 | def ResNet56(num_class=10, block=None, attention_module=None): 75 | 76 | if block == BasicBlock: 77 | n_blocks = [9, 9, 9] 78 | elif block == BottleNect: 79 | n_blocks = [6, 6, 6] 80 | 81 | return ResNetWrapper( 82 | num_blocks = n_blocks, 83 | num_class = num_class, 84 | block = block, 85 | attention_module = attention_module) 86 | 87 | def ResNet110(num_class=10, block=None, attention_module=None): 88 | 89 | if block == BasicBlock: 90 | n_blocks = [18, 18, 18] 91 | elif block == BottleNect: 92 | n_blocks = [12, 12, 12] 93 | 94 | return ResNetWrapper( 95 | num_blocks = n_blocks, 96 | num_class = num_class, 97 | block = block, 98 | attention_module = attention_module) 99 | 100 | 101 | def ResNet164(num_class=10, block=None, attention_module=None): 102 | 103 | if block == BasicBlock: 104 | n_blocks = [27, 27, 27] 105 | elif block == BottleNect: 106 | n_blocks = [18, 18, 18] 107 | 108 | return ResNetWrapper( 109 | num_blocks = n_blocks, 110 | num_class = num_class, 111 | block = block, 112 | attention_module = attention_module) -------------------------------------------------------------------------------- /mmdetection/configs/_base_/models/faster_rcnn_r50simam_fpn.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | 
type='FasterRCNN', 3 | pretrained='checkpoints/simam-net/resnet50.pth.tar', 4 | backbone=dict( 5 | type='ResNetAM', 6 | depth=50, 7 | frozen_stages=1, 8 | norm_eval=True, 9 | attention_type="simam", 10 | attention_param=0.1), 11 | neck=dict( 12 | type='FPN', 13 | in_channels=[256, 512, 1024, 2048], 14 | out_channels=256, 15 | num_outs=5), 16 | rpn_head=dict( 17 | type='RPNHead', 18 | in_channels=256, 19 | feat_channels=256, 20 | anchor_generator=dict( 21 | type='AnchorGenerator', 22 | scales=[8], 23 | ratios=[0.5, 1.0, 2.0], 24 | strides=[4, 8, 16, 32, 64]), 25 | bbox_coder=dict( 26 | type='DeltaXYWHBBoxCoder', 27 | target_means=[.0, .0, .0, .0], 28 | target_stds=[1.0, 1.0, 1.0, 1.0]), 29 | loss_cls=dict( 30 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 31 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 32 | roi_head=dict( 33 | type='StandardRoIHead', 34 | bbox_roi_extractor=dict( 35 | type='SingleRoIExtractor', 36 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 37 | out_channels=256, 38 | featmap_strides=[4, 8, 16, 32]), 39 | bbox_head=dict( 40 | type='Shared2FCBBoxHead', 41 | in_channels=256, 42 | fc_out_channels=1024, 43 | roi_feat_size=7, 44 | num_classes=80, 45 | bbox_coder=dict( 46 | type='DeltaXYWHBBoxCoder', 47 | target_means=[0., 0., 0., 0.], 48 | target_stds=[0.1, 0.1, 0.2, 0.2]), 49 | reg_class_agnostic=False, 50 | loss_cls=dict( 51 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 52 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 53 | # model training and testing settings 54 | train_cfg=dict( 55 | rpn=dict( 56 | assigner=dict( 57 | type='MaxIoUAssigner', 58 | pos_iou_thr=0.7, 59 | neg_iou_thr=0.3, 60 | min_pos_iou=0.3, 61 | match_low_quality=True, 62 | ignore_iof_thr=-1), 63 | sampler=dict( 64 | type='RandomSampler', 65 | num=256, 66 | pos_fraction=0.5, 67 | neg_pos_ub=-1, 68 | add_gt_as_proposals=False), 69 | allowed_border=-1, 70 | pos_weight=-1, 71 | debug=False), 72 | rpn_proposal=dict( 73 | nms_across_levels=False, 74 | nms_pre=2000, 75 | nms_post=1000, 76 | max_num=1000, 77 | nms_thr=0.7, 78 | min_bbox_size=0), 79 | rcnn=dict( 80 | assigner=dict( 81 | type='MaxIoUAssigner', 82 | pos_iou_thr=0.5, 83 | neg_iou_thr=0.5, 84 | min_pos_iou=0.5, 85 | match_low_quality=False, 86 | ignore_iof_thr=-1), 87 | sampler=dict( 88 | type='RandomSampler', 89 | num=512, 90 | pos_fraction=0.25, 91 | neg_pos_ub=-1, 92 | add_gt_as_proposals=True), 93 | pos_weight=-1, 94 | debug=False)), 95 | test_cfg=dict( 96 | rpn=dict( 97 | nms_across_levels=False, 98 | nms_pre=1000, 99 | nms_post=1000, 100 | max_num=1000, 101 | nms_thr=0.7, 102 | min_bbox_size=0), 103 | rcnn=dict( 104 | score_thr=0.05, 105 | nms=dict(type='nms', iou_threshold=0.5), 106 | max_per_img=100) 107 | # soft-nms is also supported for rcnn testing 108 | # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) 109 | )) 110 | -------------------------------------------------------------------------------- /networks/cifar/preresnet.py: -------------------------------------------------------------------------------- 1 | """PyTorch implementation of ResNet 2 | 3 | ResNet modifications written by Bichen Wu and Alvin Wan, based 4 | off of ResNet implementation by Kuang Liu. 5 | 6 | Reference: 7 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 8 | Deep Residual Learning for Image Recognition. 
arXiv:1512.03385 9 | """ 10 | import functools 11 | import torch.nn as nn 12 | from .block import PreBasicBlock, PreBottleNect 13 | 14 | class PreResNet(nn.Module): 15 | def __init__(self, block, num_blocks, num_class=10): 16 | super(PreResNet, self).__init__() 17 | 18 | self.num_class = num_class 19 | self.in_channels = num_base_filters = 16 20 | 21 | self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False) 22 | 23 | self.layer1 = self._make_layer(block, self.in_channels, num_blocks[0], stride=1) 24 | self.layer2 = self._make_layer(block, int(num_base_filters*2), num_blocks[1], stride=2) 25 | self.layer3 = self._make_layer(block, int(num_base_filters*4), num_blocks[2], stride=2) 26 | 27 | self.bn = nn.BatchNorm2d(int(num_base_filters*4*block(16,16,1).EXPANSION)) 28 | 29 | self.linear = nn.Linear(int(num_base_filters*4*block(16,16,1).EXPANSION), num_class) 30 | 31 | self.relu = nn.ReLU(inplace=True) 32 | 33 | self.avgpool = nn.AdaptiveAvgPool2d(1) 34 | 35 | def _make_layer(self, block, ou_channels, num_blocks, stride): 36 | strides = [stride] + [1]*(num_blocks-1) 37 | layers = [] 38 | for stride in strides: 39 | layers.append(block(self.in_channels, ou_channels, stride)) 40 | self.in_channels = int(ou_channels * block(16,16,1).EXPANSION) 41 | return nn.Sequential(*layers) 42 | 43 | def forward(self, x): 44 | out = self.conv1(x) 45 | out = self.layer1(out) 46 | out = self.layer2(out) 47 | out = self.layer3(out) 48 | out = self.bn(out) 49 | out = self.relu(out) 50 | 51 | out = self.avgpool(out) 52 | out = out.view(out.size(0), -1) 53 | out = self.linear(out) 54 | return out 55 | 56 | 57 | def PreResNetWrapper(num_blocks, num_class=10, block=None, attention_module=None): 58 | 59 | b = lambda in_planes, planes, stride: \ 60 | block(in_planes, planes, stride, attention_module=attention_module) 61 | 62 | return PreResNet(b, num_blocks, num_class=num_class) 63 | 64 | 65 | def PreResNet20(num_class=10, block=None, attention_module=None): 66 | return PreResNetWrapper( 67 | num_blocks = [3, 3, 3], 68 | num_class = num_class, 69 | block = block, 70 | attention_module = attention_module) 71 | 72 | 73 | def PreResNet32(num_class=10, block=None, attention_module=None): 74 | return PreResNetWrapper( 75 | num_blocks = [5, 5, 5], 76 | num_class = num_class, 77 | block = block, 78 | attention_module = attention_module) 79 | 80 | 81 | def PreResNet56(num_class=10, block=None, attention_module=None): 82 | 83 | if block == PreBasicBlock: 84 | n_blocks = [9, 9, 9] 85 | elif block == PreBottleNect: 86 | n_blocks = [6, 6, 6] 87 | 88 | return PreResNetWrapper( 89 | num_blocks = n_blocks, 90 | num_class = num_class, 91 | block = block, 92 | attention_module = attention_module) 93 | 94 | def PreResNet110(num_class=10, block=None, attention_module=None): 95 | 96 | if block == PreBasicBlock: 97 | n_blocks = [18, 18, 18] 98 | elif block == PreBottleNect: 99 | n_blocks = [12, 12, 12] 100 | 101 | return PreResNetWrapper( 102 | num_blocks = n_blocks, 103 | num_class = num_class, 104 | block = block, 105 | attention_module = attention_module) 106 | 107 | 108 | def PreResNet164(num_class=10, block=None, attention_module=None): 109 | 110 | if block == PreBasicBlock: 111 | n_blocks = [27, 27, 27] 112 | elif block == PreBottleNect: 113 | n_blocks = [18, 18, 18] 114 | 115 | return PreResNetWrapper( 116 | num_blocks = n_blocks, 117 | num_class = num_class, 118 | block = block, 119 | attention_module = attention_module) 
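
For orientation, here is a minimal sketch of how the CIFAR networks above are assembled with an attention module via `create_net` in `networks/cifar/__init__.py`. The argument values are assumptions for illustration; in the repository they would come from the command line of `main_cifar.py`. Only the field names mirror what the function actually reads.

```python
import argparse
from networks.cifar import create_net

# Hypothetical argument namespace; values below are illustrative, not fixed defaults.
args = argparse.Namespace(
    arch="preresnet56",      # key into model_dict
    block_type="prebasic",   # resolved to PreBasicBlock by get_block
    attention_type="simam",  # resolved to simam_module by get_attention_module
    attention_param=0.1,     # passed as e_lambda for simam (reduction for se/cbam)
    num_class=10,
)

net = create_net(args)  # PreResNet-56 with a SimAM module applied in each residual block
```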
-------------------------------------------------------------------------------- /networks/attentions/cbam_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class BasicConv(nn.Module): 7 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 8 | super(BasicConv, self).__init__() 9 | self.out_channels = out_planes 10 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 11 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 12 | self.relu = nn.ReLU() if relu else None 13 | 14 | def forward(self, x): 15 | x = self.conv(x) 16 | if self.bn is not None: 17 | x = self.bn(x) 18 | if self.relu is not None: 19 | x = self.relu(x) 20 | return x 21 | 22 | class Flatten(nn.Module): 23 | def forward(self, x): 24 | return x.view(x.size(0), -1) 25 | 26 | class ChannelGate(nn.Module): 27 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 28 | super(ChannelGate, self).__init__() 29 | self.gate_channels = gate_channels 30 | self.mlp = nn.Sequential( 31 | Flatten(), 32 | nn.Linear(gate_channels, int(gate_channels // reduction_ratio)), 33 | nn.ReLU(), 34 | nn.Linear(int(gate_channels // reduction_ratio), gate_channels) 35 | ) 36 | self.pool_types = pool_types 37 | def forward(self, x): 38 | channel_att_sum = None 39 | for pool_type in self.pool_types: 40 | if pool_type=='avg': 41 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 42 | channel_att_raw = self.mlp( avg_pool ) 43 | elif pool_type=='max': 44 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 45 | channel_att_raw = self.mlp( max_pool ) 46 | elif pool_type=='lp': 47 | lp_pool = F.lp_pool2d( x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 48 | channel_att_raw = self.mlp( lp_pool ) 49 | elif pool_type=='lse': 50 | # LSE pool only 51 | lse_pool = logsumexp_2d(x) 52 | channel_att_raw = self.mlp( lse_pool ) 53 | 54 | if channel_att_sum is None: 55 | channel_att_sum = channel_att_raw 56 | else: 57 | channel_att_sum = channel_att_sum + channel_att_raw 58 | 59 | scale = F.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 60 | return x * scale 61 | 62 | def logsumexp_2d(tensor): 63 | tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1) 64 | s, _ = torch.max(tensor_flatten, dim=2, keepdim=True) 65 | outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log() 66 | return outputs 67 | 68 | class ChannelPool(nn.Module): 69 | def forward(self, x): 70 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 71 | 72 | class SpatialGate(nn.Module): 73 | def __init__(self): 74 | super(SpatialGate, self).__init__() 75 | kernel_size = 7 76 | self.compress = ChannelPool() 77 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=int((kernel_size-1) // 2), relu=False) 78 | def forward(self, x): 79 | x_compress = self.compress(x) 80 | x_out = self.spatial(x_compress) 81 | scale = F.sigmoid(x_out) # broadcasting 82 | return x * scale 83 | 84 | class cbam_module(nn.Module): 85 | def __init__(self, gate_channels, reduction=16, pool_types=['avg', 'max'], no_spatial=False): 86 | super(cbam_module, self).__init__() 87 | 
self.ChannelGate = ChannelGate(gate_channels, reduction, pool_types) 88 | self.no_spatial=no_spatial 89 | if not no_spatial: 90 | self.SpatialGate = SpatialGate() 91 | 92 | @staticmethod 93 | def get_module_name(): 94 | return "cbam" 95 | 96 | def forward(self, x): 97 | x_out = self.ChannelGate(x) 98 | if not self.no_spatial: 99 | x_out = self.SpatialGate(x_out) 100 | return x_out -------------------------------------------------------------------------------- /mmdetection/configs/_base_/models/mask_rcnn_r50simam_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='MaskRCNN', 4 | pretrained='checkpoints/simam-net/resnet50.pth.tar', 5 | backbone=dict( 6 | type='ResNetAM', 7 | depth=50, 8 | frozen_stages=1, 9 | norm_eval=True, 10 | attention_type="simam", 11 | attention_param=0.1), 12 | neck=dict( 13 | type='FPN', 14 | in_channels=[256, 512, 1024, 2048], 15 | out_channels=256, 16 | num_outs=5), 17 | rpn_head=dict( 18 | type='RPNHead', 19 | in_channels=256, 20 | feat_channels=256, 21 | anchor_generator=dict( 22 | type='AnchorGenerator', 23 | scales=[8], 24 | ratios=[0.5, 1.0, 2.0], 25 | strides=[4, 8, 16, 32, 64]), 26 | bbox_coder=dict( 27 | type='DeltaXYWHBBoxCoder', 28 | target_means=[.0, .0, .0, .0], 29 | target_stds=[1.0, 1.0, 1.0, 1.0]), 30 | loss_cls=dict( 31 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 32 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 33 | roi_head=dict( 34 | type='StandardRoIHead', 35 | bbox_roi_extractor=dict( 36 | type='SingleRoIExtractor', 37 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 38 | out_channels=256, 39 | featmap_strides=[4, 8, 16, 32]), 40 | bbox_head=dict( 41 | type='Shared2FCBBoxHead', 42 | in_channels=256, 43 | fc_out_channels=1024, 44 | roi_feat_size=7, 45 | num_classes=80, 46 | bbox_coder=dict( 47 | type='DeltaXYWHBBoxCoder', 48 | target_means=[0., 0., 0., 0.], 49 | target_stds=[0.1, 0.1, 0.2, 0.2]), 50 | reg_class_agnostic=False, 51 | loss_cls=dict( 52 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 53 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 54 | mask_roi_extractor=dict( 55 | type='SingleRoIExtractor', 56 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 57 | out_channels=256, 58 | featmap_strides=[4, 8, 16, 32]), 59 | mask_head=dict( 60 | type='FCNMaskHead', 61 | num_convs=4, 62 | in_channels=256, 63 | conv_out_channels=256, 64 | num_classes=80, 65 | loss_mask=dict( 66 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 67 | # model training and testing settings 68 | train_cfg=dict( 69 | rpn=dict( 70 | assigner=dict( 71 | type='MaxIoUAssigner', 72 | pos_iou_thr=0.7, 73 | neg_iou_thr=0.3, 74 | min_pos_iou=0.3, 75 | match_low_quality=True, 76 | ignore_iof_thr=-1), 77 | sampler=dict( 78 | type='RandomSampler', 79 | num=256, 80 | pos_fraction=0.5, 81 | neg_pos_ub=-1, 82 | add_gt_as_proposals=False), 83 | allowed_border=-1, 84 | pos_weight=-1, 85 | debug=False), 86 | rpn_proposal=dict( 87 | nms_across_levels=False, 88 | nms_pre=2000, 89 | nms_post=1000, 90 | max_num=1000, 91 | nms_thr=0.7, 92 | min_bbox_size=0), 93 | rcnn=dict( 94 | assigner=dict( 95 | type='MaxIoUAssigner', 96 | pos_iou_thr=0.5, 97 | neg_iou_thr=0.5, 98 | min_pos_iou=0.5, 99 | match_low_quality=True, 100 | ignore_iof_thr=-1), 101 | sampler=dict( 102 | type='RandomSampler', 103 | num=512, 104 | pos_fraction=0.25, 105 | neg_pos_ub=-1, 106 | add_gt_as_proposals=True), 107 | mask_size=28, 108 | 
            pos_weight=-1,
            debug=False)),
    test_cfg=dict(
        rpn=dict(
            nms_across_levels=False,
            nms_pre=1000,
            nms_post=1000,
            max_num=1000,
            nms_thr=0.7,
            min_bbox_size=0),
        rcnn=dict(
            score_thr=0.05,
            nms=dict(type='nms', iou_threshold=0.5),
            max_per_img=100,
            mask_thr_binary=0.5)))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks
[**Lingxiao Yang**](https://zjjconan.github.io/), [Ru-Yuan Zhang](https://ruyuanzhang.github.io/), [Lida Li](https://github.com/lld533), [Xiaohua Xie](http://cse.sysu.edu.cn/content/2478)

Abstract
----------
In this paper, we propose a conceptually simple but very effective attention module for Convolutional Neural Networks (ConvNets). In contrast to existing channel-wise and spatial-wise attention modules, our module instead infers 3-D attention weights for the feature map in a layer without adding parameters to the original networks. Specifically, building on some well-known neuroscience theories, we propose to optimize an energy function to find the importance of each neuron. We further derive a fast closed-form solution for the energy function, and show that the solution can be implemented in less than ten lines of code. Another advantage of the module is that most of the operators are selected based on the solution to the defined energy function, avoiding much effort on structure tuning. Quantitative evaluations on various visual tasks demonstrate that the proposed module is flexible and effective in improving the representation ability of many ConvNets. Our code is available at [Pytorch-SimAM](https://github.com/ZjjConan/SimAM).

--------------------------------------------------

Our environments and toolkits
-----------

- OS: Ubuntu 18.04.5
- CUDA: 11.0
- Python: 3.8.3
- Toolkit: PyTorch 1.8.0
- GPU: Quadro RTX 8000 (4x)
- [thop](https://github.com/Lyken17/pytorch-OpCounter)


Module
------

Our goal is to infer 3-D attention weights (Figure (c)) for a given feature map, which is very different from previous works that produce the channel-wise or spatial-wise weights shown in Figures (a) and (b).


<div align="center"><img src="figures/attentions.png" alt="Figure: (a) channel-wise attention and (b) spatial-wise attention vs. (c) full 3-D attention weights"></div>
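
For reference, the fast closed-form solution mentioned in the abstract can be written as follows (notation follows the paper: $t$ is a target neuron, $\hat{\mu}$ and $\hat{\sigma}^2$ are the mean and variance computed over the neurons in the same channel, and $\lambda$ is a regularization coefficient):

$$e_t^{\ast} = \frac{4\,(\hat{\sigma}^2 + \lambda)}{(t - \hat{\mu})^2 + 2\hat{\sigma}^2 + 2\lambda}, \qquad \frac{1}{e_t^{\ast}} = \frac{(t - \hat{\mu})^2}{4\,(\hat{\sigma}^2 + \lambda)} + 0.5$$

A lower minimal energy $e_t^{\ast}$ indicates that neuron $t$ is more distinct from its surroundings, so each neuron is reweighted by $\mathrm{sigmoid}(1/e_t^{\ast})$; the `E_inv` variable in the implementation below is exactly $1/e_t^{\ast}$.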

**SimAM (a PyTorch-like implementation).** Details of the implementation, including the module and the networks, can be found in ``networks`` in this repository. The snippet below follows `networks/attentions/simam_module.py` and is written to be runnable as-is: `lambda` from the paper is renamed `e_lambda` because `lambda` is a reserved word in Python, and the `keepdim=True` flags are required so the per-channel statistics broadcast against `X`.

```python
import torch
import torch.nn as nn


class SimAM(nn.Module):
    # e_lambda: coefficient λ in Eqn (5)
    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.e_lambda = e_lambda

    # X: input feature [N, C, H, W]
    def forward(self, X):
        # spatial size
        n = X.shape[2] * X.shape[3] - 1
        # square of (t - u)
        d = (X - X.mean(dim=[2, 3], keepdim=True)).pow(2)
        # d.sum() / n is channel variance
        v = d.sum(dim=[2, 3], keepdim=True) / n
        # E_inv groups all importance of X
        E_inv = d / (4 * (v + self.e_lambda)) + 0.5
        # return attended features
        return X * torch.sigmoid(E_inv)
```

Training and Validation Curves
----------


<div align="center"><img src="figures/training_curves.png" alt="Training and validation curves"></div>

Experiments
----------

### Training and Evaluation

The following commands train models on ImageNet from scratch with 4 GPUs.


```
# Training from scratch

python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 100 -j 20 -a resnet18

python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 100 -j 20 -a resnet18 \
    --attention_type simam --attention_param 0.1

python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 --epochs 150 -j 20 -a mobilenet_v2 \
    --attention_type simam --attention_param 0.1 --lr .05 --cos_lr --wd 4e-5
```

```
# Evaluating the trained model

python main_imagenet.py {the path of ImageNet} --gpu 0,1,2,3 -j 20 -a resnet18 -e \
    --resume {the path of pretrained .pth}
```

### ImageNet

All the following models can be downloaded from **[BaiduYunPan](https://pan.baidu.com/s/1i_imf4Ny4U9SkD5e9nhCJw)** (extract code: **25tp**) and **[Google Drive](https://drive.google.com/drive/folders/1rRT0UCPeRLPdTCJvv43hvAnGnS49nIWn?usp=sharing)**.

|Model              |Parameters |FLOPs  |Top-1(%) |Top-5(%)|
|:---:              |:----:     |:---:  |:------: |:------:|
|SimAM-R18          |11.69 M    |1.82 G |71.31    |89.88   |
|SimAM-R34          |21.80 M    |3.67 G |74.49    |92.02   |
|SimAM-R50          |25.56 M    |4.11 G |77.45    |93.66   |
|SimAM-R101         |44.55 M    |7.83 G |78.65    |94.11   |
|SimAM-RX50 (32x4d) |25.03 M    |4.26 G |78.00    |93.93   |
|SimAM-MV2          |3.50 M     |0.31 G |72.36    |90.74   |

### COCO Evaluation
We use [mmdetection](https://github.com/open-mmlab/mmdetection) to train Faster RCNN and Mask RCNN for object detection and instance segmentation. If you want to run the following models, please first install `mmdetection` following its guide, and then copy every `.py` file under the `mmdetection` folder of this repository into the corresponding folders of your `mmdetection` checkout, as sketched below.
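
A minimal sketch of that copy step, assuming this repository is checked out next to your `mmdetection` clone (both directory names below are assumptions, not fixed paths):

```
cp -r SimAM/mmdetection/configs/. mmdetection/configs/
cp -r SimAM/mmdetection/mmdet/. mmdetection/mmdet/
```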
All the following models can be downloaded from **[BaiduYunPan](https://pan.baidu.com/s/1NtMgu09vv0tEhb2PsXCsQQ)** (extract code: **ysrz**) and **[Google Drive](https://drive.google.com/drive/folders/1F8W3MY32crU6jUeV2sgc_4AQwqt_MvAp?usp=sharing)**.

#### Detection with Faster RCNN (FR for short) and Mask RCNN (MR for short)

|Model          |AP     |AP_50 |AP_75|AP_S   |AP_M  |AP_L |
|:----:         |:----: |:---: |:--: |:----: |:---: |:--: |
|FR-SimAM-R50   |39.2   |60.7  |40.8 |22.8   |43.0  |50.6 |
|FR-SimAM-R101  |41.2   |62.4  |45.0 |24.0   |45.6  |52.8 |
|MR-SimAM-R50   |39.8   |61.0  |43.4 |23.1   |43.7  |51.4 |
|MR-SimAM-R101  |41.8   |62.8  |46.0 |24.8   |46.2  |53.9 |


#### Instance Segmentation with Mask RCNN (MR for short)

|Model          |AP     |AP_50 |AP_75|AP_S   |AP_M  |AP_L |
|:----:         |:----: |:---: |:--: |:----: |:---: |:--: |
|MR-SimAM-R50   |36.0   |57.9  |38.2 |19.1   |39.7  |48.6 |
|MR-SimAM-R101  |37.6   |59.5  |40.1 |20.5   |41.5  |50.8 |

--------------------------------------------------------------------


Citation
--------
If you find SimAM useful in your research, please consider citing:

    @InProceedings{pmlr-v139-yang21o,
        title = {SimAM: A Simple, Parameter-Free Attention Module for Convolutional Neural Networks},
        author = {Yang, Lingxiao and Zhang, Ru-Yuan and Li, Lida and Xie, Xiaohua},
        booktitle = {Proceedings of the 38th International Conference on Machine Learning},
        pages = {11863--11874},
        year = {2021},
        editor = {Meila, Marina and Zhang, Tong},
        volume = {139},
        series = {Proceedings of Machine Learning Research},
        month = {18--24 Jul},
        publisher = {PMLR},
        pdf = {http://proceedings.mlr.press/v139/yang21o/yang21o.pdf},
        url = {http://proceedings.mlr.press/v139/yang21o.html}
    }

## Contact Information

If you have any suggestions or questions, you can contact us at lingxiao.yang717@gmail.com. Thanks for your attention!
--------------------------------------------------------------------------------
/networks/imagenet/mobilenetv2.py:
--------------------------------------------------------------------------------
import functools
from torch import nn
from torch import Tensor
# from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, List


__all__ = ['MobileNetV2', 'mobilenet_v2']


# model_urls = {
#     'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
# }


def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | class ConvBNActivation(nn.Sequential): 37 | def __init__( 38 | self, 39 | in_planes: int, 40 | out_planes: int, 41 | kernel_size: int = 3, 42 | stride: int = 1, 43 | groups: int = 1, 44 | norm_layer: Optional[Callable[..., nn.Module]] = None, 45 | activation_layer: Optional[Callable[..., nn.Module]] = None, 46 | attention_module: Optional[Callable[..., nn.Module]] = None, 47 | ) -> None: 48 | padding = (kernel_size - 1) // 2 49 | if norm_layer is None: 50 | norm_layer = nn.BatchNorm2d 51 | if activation_layer is None: 52 | activation_layer = nn.ReLU6 53 | if attention_module is not None: 54 | if type(attention_module) == functools.partial: 55 | module_name = attention_module.func.get_module_name() 56 | else: 57 | module_name = attention_module.get_module_name() 58 | 59 | if module_name == "simam": 60 | super(ConvBNReLU, self).__init__( 61 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 62 | attention_module(out_planes), 63 | norm_layer(out_planes), 64 | activation_layer(inplace=True) 65 | ) 66 | else: 67 | super(ConvBNReLU, self).__init__( 68 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 69 | norm_layer(out_planes), 70 | activation_layer(inplace=True) 71 | ) 72 | else: 73 | super(ConvBNReLU, self).__init__( 74 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 75 | norm_layer(out_planes), 76 | activation_layer(inplace=True) 77 | ) 78 | 79 | 80 | # necessary for backwards compatibility 81 | ConvBNReLU = ConvBNActivation 82 | 83 | 84 | class InvertedResidual(nn.Module): 85 | def __init__( 86 | self, 87 | inp: int, 88 | oup: int, 89 | stride: int, 90 | expand_ratio: int, 91 | norm_layer: Optional[Callable[..., nn.Module]] = None, 92 | attention_module: Optional[Callable[..., nn.Module]] = None 93 | ) -> None: 94 | super(InvertedResidual, self).__init__() 95 | self.stride = stride 96 | assert stride in [1, 2] 97 | 98 | if norm_layer is None: 99 | norm_layer = nn.BatchNorm2d 100 | 101 | hidden_dim = int(round(inp * expand_ratio)) 102 | self.use_res_connect = self.stride == 1 and inp == oup 103 | 104 | layers: List[nn.Module] = [] 105 | if expand_ratio != 1: 106 | # pw 107 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 108 | layers.extend([ 109 | # dw 110 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer, attention_module=attention_module), 111 | # pw-linear 112 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 113 | norm_layer(oup), 114 | ]) 115 | 116 | if attention_module is not None: 117 | if type(attention_module) == functools.partial: 118 | module_name = attention_module.func.get_module_name() 119 | else: 120 | module_name = attention_module.get_module_name() 121 | 122 | if module_name != "simam": 123 | # print(attention_module) 124 | layers.append(attention_module(oup)) 125 | 126 | self.conv = nn.Sequential(*layers) 127 | 128 | def forward(self, x: Tensor) -> Tensor: 129 | if self.use_res_connect: 130 | return x + self.conv(x) 131 | else: 132 | return self.conv(x) 133 | 134 | 135 | class MobileNetV2(nn.Module): 136 | def __init__( 137 | self, 138 | num_classes: int = 1000, 139 | width_mult: float = 1.0, 140 | inverted_residual_setting: Optional[List[List[int]]] = None, 141 | round_nearest: int = 8, 142 | block: Optional[Callable[..., nn.Module]] = None, 143 | norm_layer: Optional[Callable[..., 
nn.Module]] = None, 144 | attention_module: Optional[Callable[..., nn.Module]] = None 145 | ) -> None: 146 | """ 147 | MobileNet V2 main class 148 | 149 | Args: 150 | num_classes (int): Number of classes 151 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 152 | inverted_residual_setting: Network structure 153 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 154 | Set to 1 to turn off rounding 155 | block: Module specifying inverted residual building block for mobilenet 156 | norm_layer: Module specifying the normalization layer to use 157 | 158 | attention_module: Module specifying the attention layer to use 159 | """ 160 | super(MobileNetV2, self).__init__() 161 | 162 | if block is None: 163 | block = InvertedResidual 164 | 165 | if norm_layer is None: 166 | norm_layer = nn.BatchNorm2d 167 | 168 | input_channel = 32 169 | last_channel = 1280 170 | 171 | if inverted_residual_setting is None: 172 | inverted_residual_setting = [ 173 | # t, c, n, s 174 | [1, 16, 1, 1], 175 | [6, 24, 2, 2], 176 | [6, 32, 3, 2], 177 | [6, 64, 4, 2], 178 | [6, 96, 3, 1], 179 | [6, 160, 3, 2], 180 | [6, 320, 1, 1], 181 | ] 182 | 183 | # only check the first element, assuming user knows t,c,n,s are required 184 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 185 | raise ValueError("inverted_residual_setting should be non-empty " 186 | "or a 4-element list, got {}".format(inverted_residual_setting)) 187 | 188 | # building first layer 189 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 190 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 191 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 192 | # building inverted residual blocks 193 | for t, c, n, s in inverted_residual_setting: 194 | output_channel = _make_divisible(c * width_mult, round_nearest) 195 | for i in range(n): 196 | stride = s if i == 0 else 1 197 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, attention_module=attention_module)) 198 | input_channel = output_channel 199 | # building last several layers 200 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 201 | # make it nn.Sequential 202 | self.features = nn.Sequential(*features) 203 | 204 | # building classifier 205 | self.classifier = nn.Sequential( 206 | # nn.Dropout(0.2), 207 | nn.Linear(self.last_channel, num_classes), 208 | ) 209 | 210 | # weight initialization 211 | for m in self.modules(): 212 | if isinstance(m, nn.Conv2d): 213 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 214 | if m.bias is not None: 215 | nn.init.zeros_(m.bias) 216 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 217 | nn.init.ones_(m.weight) 218 | nn.init.zeros_(m.bias) 219 | elif isinstance(m, nn.Linear): 220 | nn.init.normal_(m.weight, 0, 0.01) 221 | if m.bias is not None: 222 | nn.init.zeros_(m.bias) 223 | 224 | def _forward_impl(self, x: Tensor) -> Tensor: 225 | # This exists since TorchScript doesn't support inheritance, so the superclass method 226 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 227 | x = self.features(x) 228 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 229 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1) 230 | x = self.classifier(x) 
231 | return x 232 | 233 | def forward(self, x: Tensor) -> Tensor: 234 | return self._forward_impl(x) 235 | 236 | 237 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: 238 | """ 239 | Constructs a MobileNetV2 architecture from 240 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 241 | 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | progress (bool): If True, displays a progress bar of the download to stderr 245 | """ 246 | model = MobileNetV2(**kwargs) 247 | # if pretrained: 248 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 249 | # progress=progress) 250 | # model.load_state_dict(state_dict) 251 | return model -------------------------------------------------------------------------------- /networks/cifar/block.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def conv3x3(in_channels, ou_channels, stride=1): 8 | return nn.Conv2d(in_channels, ou_channels, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | def conv1x1(in_channels, ou_channels, stride=1): 12 | return nn.Conv2d(in_channels, ou_channels, kernel_size=1, stride=stride, padding=0, bias=False) 13 | 14 | 15 | # Basic Block in ResNet for CIFAR 16 | class BasicBlock(nn.Module): 17 | 18 | EXPANSION = 1 19 | 20 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None): 21 | super(BasicBlock, self).__init__() 22 | 23 | self.relu = nn.ReLU(inplace=True) 24 | 25 | self.conv1 = conv3x3(in_channels, ou_channels, stride=stride) 26 | self.bn1 = nn.BatchNorm2d(ou_channels) 27 | 28 | self.conv2 = conv3x3(ou_channels, ou_channels * self.EXPANSION, stride=1) 29 | self.bn2 = nn.BatchNorm2d(ou_channels * self.EXPANSION) 30 | 31 | if attention_module is not None: 32 | if type(attention_module) == functools.partial: 33 | module_name = attention_module.func.get_module_name() 34 | else: 35 | module_name = attention_module.get_module_name() 36 | 37 | 38 | if module_name == "simam": 39 | self.conv2 = nn.Sequential( 40 | self.conv2, 41 | attention_module(ou_channels * self.EXPANSION) 42 | ) 43 | else: 44 | self.bn2 = nn.Sequential( 45 | self.bn2, 46 | attention_module(ou_channels * self.EXPANSION) 47 | ) 48 | 49 | self.shortcut = nn.Sequential() 50 | if stride != 1 or in_channels != ou_channels * self.EXPANSION: 51 | self.shortcut = nn.Sequential( 52 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride), 53 | nn.BatchNorm2d(ou_channels * self.EXPANSION) 54 | ) 55 | 56 | def forward(self, x): 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | 64 | out += self.shortcut(x) 65 | 66 | return self.relu(out) 67 | 68 | # Bottlenect in ResNet for CIFAR 69 | class BottleNect(nn.Module): 70 | 71 | EXPANSION = 4 72 | 73 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None): 74 | super(BottleNect, self).__init__() 75 | 76 | self.relu = nn.ReLU(inplace=True) 77 | 78 | self.conv1 = conv1x1(in_channels, ou_channels, stride=1) 79 | self.bn1 = nn.BatchNorm2d(ou_channels) 80 | 81 | self.conv2 = conv3x3(ou_channels, ou_channels, stride=stride) 82 | self.bn2 = nn.BatchNorm2d(ou_channels) 83 | 84 | self.conv3 = conv1x1(ou_channels, ou_channels * self.EXPANSION, stride=1) 85 | self.bn3 = nn.BatchNorm2d(ou_channels * self.EXPANSION) 86 | 87 | if 
attention_module is not None: 88 | if type(attention_module) == functools.partial: 89 | module_name = attention_module.func.get_module_name() 90 | else: 91 | module_name = attention_module.get_module_name() 92 | 93 | if module_name == "simam": 94 | self.conv2 = nn.Sequential( 95 | self.conv2, 96 | attention_module(ou_channels * self.EXPANSION) 97 | ) 98 | else: 99 | self.bn3 = nn.Sequential( 100 | self.bn3, 101 | attention_module(ou_channels * self.EXPANSION) 102 | ) 103 | 104 | self.shortcut = nn.Sequential() 105 | if stride != 1 or in_channels != ou_channels * self.EXPANSION: 106 | self.shortcut = nn.Sequential( 107 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride), 108 | nn.BatchNorm2d(ou_channels * self.EXPANSION) 109 | ) 110 | 111 | def forward(self, x): 112 | out = self.conv1(x) 113 | out = self.bn1(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv2(out) 117 | out = self.bn2(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv3(out) 121 | out = self.bn3(out) 122 | 123 | out += self.shortcut(x) 124 | 125 | return self.relu(out) 126 | 127 | 128 | 129 | # PreBasic Block in ResNet for CIFAR 130 | class PreBasicBlock(nn.Module): 131 | 132 | EXPANSION = 1 133 | 134 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None): 135 | super(PreBasicBlock, self).__init__() 136 | 137 | self.relu = nn.ReLU(inplace=True) 138 | 139 | self.bn1 = nn.BatchNorm2d(in_channels) 140 | self.conv1 = conv3x3(in_channels, ou_channels, stride=stride) 141 | 142 | self.bn2 = nn.BatchNorm2d(ou_channels) 143 | self.conv2 = conv3x3(ou_channels, ou_channels * self.EXPANSION, stride=1) 144 | 145 | if attention_module is not None: 146 | self.conv2 = nn.Sequential( 147 | self.conv2, 148 | attention_module(ou_channels * self.EXPANSION) 149 | ) 150 | 151 | self.shortcut = nn.Sequential() 152 | if stride != 1 or in_channels != ou_channels * self.EXPANSION: 153 | self.shortcut = nn.Sequential( 154 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride) 155 | ) 156 | 157 | def forward(self, x): 158 | out = self.bn1(x) 159 | out = self.relu(out) 160 | out = self.conv1(out) 161 | 162 | out = self.bn2(out) 163 | out = self.relu(out) 164 | out = self.conv2(out) 165 | 166 | out = out + self.shortcut(x) 167 | return out 168 | 169 | 170 | # Bottlenect in ResNet for CIFAR 171 | class PreBottleNect(nn.Module): 172 | 173 | EXPANSION = 4 174 | 175 | def __init__(self, in_channels, ou_channels, stride=1, attention_module=None): 176 | super(PreBottleNect, self).__init__() 177 | 178 | self.relu = nn.ReLU(inplace=True) 179 | 180 | self.bn1 = nn.BatchNorm2d(in_channels) 181 | self.conv1 = conv1x1(in_channels, ou_channels, stride=1) 182 | 183 | self.bn2 = nn.BatchNorm2d(ou_channels) 184 | self.conv2 = conv3x3(ou_channels, ou_channels, stride=stride) 185 | 186 | self.bn3 = nn.BatchNorm2d(ou_channels) 187 | self.conv3 = conv1x1(ou_channels, ou_channels * self.EXPANSION, stride=1) 188 | 189 | if attention_module is not None: 190 | if type(attention_module) == functools.partial: 191 | module_name = attention_module.func.get_module_name() 192 | else: 193 | module_name = attention_module.get_module_name() 194 | 195 | if module_name == "simam": 196 | self.conv2 = nn.Sequential( 197 | self.conv2, 198 | attention_module(ou_channels * self.EXPANSION) 199 | ) 200 | else: 201 | self.conv3 = nn.Sequential( 202 | self.conv3, 203 | attention_module(ou_channels * self.EXPANSION) 204 | ) 205 | 206 | self.shortcut = nn.Sequential() 207 | if stride != 1 or in_channels != ou_channels * 
self.EXPANSION: 208 | self.shortcut = nn.Sequential( 209 | conv1x1(in_channels, ou_channels * self.EXPANSION, stride=stride) 210 | ) 211 | 212 | def forward(self, x): 213 | out = self.bn1(x) 214 | out = self.relu(out) 215 | out = self.conv1(out) 216 | 217 | out = self.bn2(out) 218 | out = self.relu(out) 219 | out = self.conv2(out) 220 | 221 | out = self.bn3(out) 222 | out = self.relu(out) 223 | out = self.conv3(out) 224 | 225 | out = out + self.shortcut(x) 226 | 227 | return out 228 | 229 | 230 | class WideBasicBlock(nn.Module): 231 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, attention_module=None): 232 | super(WideBasicBlock, self).__init__() 233 | self.bn1 = nn.BatchNorm2d(in_planes) 234 | self.relu1 = nn.ReLU(inplace=True) 235 | self.conv1 = conv3x3(in_planes, out_planes, stride=stride) 236 | 237 | self.bn2 = nn.BatchNorm2d(out_planes) 238 | self.relu2 = nn.ReLU(inplace=True) 239 | self.conv2 = conv3x3(out_planes, out_planes, stride=1) 240 | 241 | if attention_module is not None: 242 | self.conv2 = nn.Sequential( 243 | self.conv2, 244 | attention_module(out_planes) 245 | ) 246 | self.droprate = dropRate 247 | self.equalInOut = (in_planes == out_planes) 248 | self.convShortcut = (not self.equalInOut) and conv1x1(in_planes, out_planes, stride=stride) or None 249 | def forward(self, x): 250 | if not self.equalInOut: 251 | x = self.relu1(self.bn1(x)) 252 | else: 253 | out = self.relu1(self.bn1(x)) 254 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 255 | if self.droprate > 0: 256 | out = F.dropout(out, p=self.droprate, training=self.training) 257 | out = self.conv2(out) 258 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 259 | 260 | 261 | class InvertedResidualBlock(nn.Module): 262 | '''expand + depthwise + pointwise''' 263 | def __init__(self, in_planes, out_planes, expansion, stride, attention_module=None): 264 | super(InvertedResidualBlock, self).__init__() 265 | self.stride = stride 266 | 267 | planes = expansion * in_planes 268 | self.conv1 = conv1x1(in_planes, planes, stride=1) 269 | self.bn1 = nn.BatchNorm2d(planes) 270 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False) 271 | self.bn2 = nn.BatchNorm2d(planes) 272 | self.conv3 = conv1x1(planes, out_planes, stride=1) 273 | self.bn3 = nn.BatchNorm2d(out_planes) 274 | 275 | self.relu = nn.ReLU(inplace=True) 276 | 277 | if attention_module is not None: 278 | if type(attention_module) == functools.partial: 279 | module_name = attention_module.func.get_module_name() 280 | else: 281 | module_name = attention_module.get_module_name() 282 | 283 | if module_name == "simam": 284 | self.conv2 = nn.Sequential( 285 | self.conv2, 286 | attention_module(planes) 287 | ) 288 | else: 289 | self.bn3 = nn.Sequential( 290 | self.bn3, 291 | attention_module(out_planes) 292 | ) 293 | 294 | self.shortcut = nn.Sequential() 295 | if stride == 1 and in_planes != out_planes: 296 | self.shortcut = nn.Sequential( 297 | conv1x1(in_planes, out_planes, stride=1), 298 | nn.BatchNorm2d(out_planes), 299 | ) 300 | 301 | def forward(self, x): 302 | out = self.conv1(x) 303 | out = self.bn1(out) 304 | out = self.relu(out) 305 | 306 | out = self.conv2(out) 307 | out = self.bn2(out) 308 | out = self.relu(out) 309 | 310 | out = self.conv3(out) 311 | out = self.bn3(out) 312 | 313 | out = out + self.shortcut(x) if self.stride == 1 else out 314 | 315 | return out -------------------------------------------------------------------------------- 
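
A note on the "simam" branches that appear in every block above: modules are routed by name because they attach at different points. SE/CBAM-style modules are appended after the last normalization layer, while a module reporting "simam" is wrapped around the preceding 3x3 convolution, before batch norm. The blocks only assume that the module takes the channel count as its first constructor argument, exposes a static get_module_name(), and preserves the input shape. The repository's implementation lives in networks/attentions/simam_module.py (not reproduced in this dump); below is a minimal sketch consistent with that interface, following the energy formulation of the SimAM paper, with e_lambda being the value the training scripts pass as --attention_param:

import torch.nn as nn


class simam_module(nn.Module):
    # Parameter-free attention: every activation is re-weighted by a sigmoid
    # of its inverse energy, so no learnable weights are added to the block.
    def __init__(self, channels=None, e_lambda=1e-4):
        super(simam_module, self).__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda  # regularizer in the energy denominator

    @staticmethod
    def get_module_name():
        # The blocks above dispatch on this string.
        return "simam"

    def forward(self, x):
        # number of neurons per channel, excluding the target neuron
        n = x.size(2) * x.size(3) - 1
        # squared deviation of each activation from its channel mean
        d = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        # inverse energy: distinctive (low-energy) neurons get large weights
        y = d / (4 * (d.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)

resnet_simam.py below binds the same module with functools.partial(simam_module, e_lambda=attention_param) before handing it to each block, which is why all the blocks unwrap functools.partial when they query the module name.
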
/mmdetection/mmdet/models/backbones/resnet_simam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | 5 | from mmcv.cnn import constant_init, kaiming_init 6 | from mmcv.runner import load_checkpoint 7 | from mmdet.utils import get_root_logger 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | from ..builder import BACKBONES 10 | from .attentions import simam_module 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=dilation, groups=groups, bias=False, dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, 28 | inplanes, 29 | planes, 30 | stride=1, 31 | downsample=None, 32 | groups=1, 33 | base_width=64, 34 | dilation=1, 35 | norm_layer=None, 36 | attention_module=None): 37 | 38 | super(BasicBlock, self).__init__() 39 | if norm_layer is None: 40 | norm_layer = nn.BatchNorm2d 41 | if groups != 1 or base_width != 64: 42 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 43 | if dilation > 1: 44 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 45 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 46 | self.conv1 = conv3x3(inplanes, planes, stride) 47 | self.bn1 = norm_layer(planes) 48 | self.relu = nn.ReLU(inplace=True) 49 | self.conv2 = conv3x3(planes, planes) 50 | self.bn2 = norm_layer(planes) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | if attention_module is not None: 55 | if type(attention_module) == functools.partial: 56 | module_name = attention_module.func.get_module_name() 57 | else: 58 | module_name = attention_module.get_module_name() 59 | 60 | if module_name == "simam": 61 | self.conv2 = nn.Sequential( 62 | self.conv2, 63 | attention_module(planes) 64 | ) 65 | else: 66 | self.bn2 = nn.Sequential( 67 | self.bn2, 68 | attention_module(planes) 69 | ) 70 | 71 | def forward(self, x): 72 | identity = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | 81 | if self.downsample is not None: 82 | identity = self.downsample(x) 83 | 84 | out += identity 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class Bottleneck(nn.Module): 91 | 92 | expansion = 4 93 | 94 | def __init__(self, 95 | inplanes, 96 | planes, 97 | stride=1, 98 | downsample=None, 99 | groups=1, 100 | base_width=64, 101 | dilation=1, 102 | norm_layer=None, 103 | attention_module=None): 104 | 105 | super(Bottleneck, self).__init__() 106 | if norm_layer is None: 107 | norm_layer = nn.BatchNorm2d 108 | width = int(planes * (base_width / 64.)) * groups 109 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 110 | self.conv1 = conv1x1(inplanes, width) 111 | self.bn1 = norm_layer(width) 112 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 113 | self.bn2 = norm_layer(width) 114 | self.conv3 = conv1x1(width, planes * self.expansion) 115 | self.bn3 = norm_layer(planes * self.expansion) 116 | self.relu = nn.ReLU(inplace=True) 117 | self.downsample = downsample 118 | self.stride = stride 119 | 120 
| if attention_module is not None: 121 | if type(attention_module) == functools.partial: 122 | module_name = attention_module.func.get_module_name() 123 | else: 124 | module_name = attention_module.get_module_name() 125 | 126 | if module_name == "simam": 127 | self.conv2 = nn.Sequential( 128 | self.conv2, 129 | attention_module(width) 130 | ) 131 | else: 132 | self.bn3 = nn.Sequential( 133 | self.bn3, 134 | attention_module(planes * self.expansion) 135 | ) 136 | 137 | def forward(self, x): 138 | identity = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | out = self.relu(out) 147 | 148 | out = self.conv3(out) 149 | out = self.bn3(out) 150 | 151 | if self.downsample is not None: 152 | identity = self.downsample(x) 153 | 154 | out += identity 155 | out = self.relu(out) 156 | 157 | return out 158 | 159 | 160 | @BACKBONES.register_module() 161 | class ResNetAM(nn.Module): 162 | 163 | arch_settings = { 164 | 18: (BasicBlock, (2, 2, 2, 2)), 165 | 34: (BasicBlock, (3, 4, 6, 3)), 166 | 50: (Bottleneck, (3, 4, 6, 3)), 167 | 101: (Bottleneck, (3, 4, 23, 3)), 168 | 152: (Bottleneck, (3, 8, 36, 3)) 169 | } 170 | 171 | def __init__(self, 172 | depth, 173 | groups=1, 174 | width_per_group=64, 175 | replace_stride_with_dilation=None, 176 | norm_layer=None, 177 | norm_eval=True, 178 | frozen_stages=-1, 179 | attention_type="none", 180 | attention_param=None, 181 | zero_init_residual=False): 182 | super(ResNetAM, self).__init__() 183 | if depth not in self.arch_settings: 184 | raise KeyError(f'invalid depth {depth} for resnet') 185 | 186 | if norm_layer is None: 187 | norm_layer = nn.BatchNorm2d 188 | self._norm_layer = norm_layer 189 | 190 | self.inplanes = 64 191 | self.dilation = 1 192 | self.norm_eval = norm_eval 193 | self.frozen_stages = frozen_stages 194 | self.zero_init_residual = zero_init_residual 195 | block, stage_blocks = self.arch_settings[depth] 196 | 197 | if attention_type == "simam": 198 | attention_module = functools.partial(simam_module, e_lambda=attention_param) 199 | else: 200 | attention_module = None 201 | 202 | if replace_stride_with_dilation is None: 203 | # each element in the tuple indicates if we should replace 204 | # the 2x2 stride with a dilated convolution instead 205 | replace_stride_with_dilation = [False, False, False] 206 | if len(replace_stride_with_dilation) != 3: 207 | raise ValueError("replace_stride_with_dilation should be None " 208 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 209 | self.groups = groups 210 | self.base_width = width_per_group 211 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 212 | bias=False) 213 | self.bn1 = norm_layer(self.inplanes) 214 | self.relu1 = nn.ReLU(inplace=True) 215 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 216 | 217 | self.layer1 = self._make_layer(block, 64, stage_blocks[0], 218 | attention_module=attention_module) 219 | 220 | self.layer2 = self._make_layer(block, 128, stage_blocks[1], stride=2, 221 | dilate=replace_stride_with_dilation[0], 222 | attention_module=attention_module) 223 | 224 | self.layer3 = self._make_layer(block, 256, stage_blocks[2], stride=2, 225 | dilate=replace_stride_with_dilation[1], 226 | attention_module=attention_module) 227 | 228 | self.layer4 = self._make_layer(block, 512, stage_blocks[3], stride=2, 229 | dilate=replace_stride_with_dilation[2], 230 | attention_module=attention_module) 231 | 232 | self.avgpool = 
nn.AdaptiveAvgPool2d((1, 1))
233 |         self.fc = nn.Linear(512 * block.expansion, 1000)
234 | 
235 | 
236 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False, attention_module=None):
237 |         norm_layer = self._norm_layer
238 |         downsample = None
239 |         previous_dilation = self.dilation
240 |         if dilate:
241 |             self.dilation *= stride
242 |             stride = 1
243 |         if stride != 1 or self.inplanes != planes * block.expansion:
244 |             downsample = nn.Sequential(
245 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
246 |                 norm_layer(planes * block.expansion),
247 |             )
248 | 
249 |         layers = []
250 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
251 |                             self.base_width, previous_dilation, norm_layer, attention_module))
252 |         self.inplanes = planes * block.expansion
253 |         for _ in range(1, blocks):
254 |             layers.append(block(self.inplanes, planes, groups=self.groups,
255 |                                 base_width=self.base_width, dilation=self.dilation,
256 |                                 norm_layer=norm_layer, attention_module=attention_module))
257 | 
258 |         return nn.Sequential(*layers)
259 | 
260 | 
261 |     def _freeze_stages(self):
262 |         if self.frozen_stages >= 0:
263 |             self.bn1.eval()
264 |             for m in [self.conv1, self.bn1]:
265 |                 for param in m.parameters():
266 |                     param.requires_grad = False
267 | 
268 |         for i in range(1, self.frozen_stages + 1):
269 |             m = getattr(self, f'layer{i}')
270 |             m.eval()
271 |             for param in m.parameters():
272 |                 param.requires_grad = False
273 | 
274 |     def init_weights(self, pretrained=None):
275 |         """Initialize the weights in backbone.
276 | 
277 |         Args:
278 |             pretrained (str, optional): Path to pre-trained weights.
279 |                 Defaults to None.
280 |         """
281 | 
282 |         self.fc = None
283 |         self.avgpool = None
284 |         if isinstance(pretrained, str):
285 |             logger = get_root_logger()
286 |             load_checkpoint(self, pretrained, strict=False, logger=logger)
287 |         elif pretrained is None:
288 |             for m in self.modules():
289 |                 if isinstance(m, nn.Conv2d):
290 |                     kaiming_init(m)
291 |                 elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
292 |                     constant_init(m, 1)
293 | 
294 |             if self.zero_init_residual:
295 |                 for m in self.modules():
296 |                     if isinstance(m, Bottleneck):
297 |                         constant_init(m.bn3, 0)
298 |                     elif isinstance(m, BasicBlock):
299 |                         constant_init(m.bn2, 0)
300 |         else:
301 |             raise TypeError('pretrained must be a str or None')
302 | 
303 |     def forward(self, x):
304 |         # See note [TorchScript super()]
305 |         x = self.conv1(x)
306 |         x = self.bn1(x)
307 |         x = self.relu1(x)
308 |         x = self.maxpool(x)
309 | 
310 |         outs = []
311 | 
312 |         x = self.layer1(x)
313 |         outs.append(x)
314 | 
315 |         x = self.layer2(x)
316 |         outs.append(x)
317 | 
318 |         x = self.layer3(x)
319 |         outs.append(x)
320 | 
321 |         x = self.layer4(x)
322 |         outs.append(x)
323 | 
324 |         return tuple(outs)
325 | 
326 |     def train(self, mode=True):
327 |         """Convert the model into training mode while keeping the
328 |         normalization layers frozen."""
329 |         super(ResNetAM, self).train(mode)
330 |         self._freeze_stages()
331 |         if mode and self.norm_eval:
332 |             for m in self.modules():
333 |                 # trick: eval has an effect on BatchNorm only
334 |                 if isinstance(m, _BatchNorm):
335 |                     m.eval()
--------------------------------------------------------------------------------
/main_cifar.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | # system lib
4 | import os
5 | import time
6 | import sys
7 | import argparse
8 | # numerical libs
9 | import random
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim as
optim 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn.functional as F 15 | from torchvision import datasets, transforms 16 | # models 17 | from thop import profile 18 | from util import AverageMeter, ProgressMeter, accuracy, parse_gpus 19 | from checkpoint import save_checkpoint, load_checkpoint 20 | from networks.cifar import create_net 21 | 22 | 23 | def adjust_learning_rate(optimizer, epoch, warmup=False): 24 | """Adjust the learning rate""" 25 | if epoch <= 81: 26 | lr = 0.01 if warmup and epoch == 0 else args.base_lr 27 | elif epoch <= 122: 28 | lr = args.base_lr * 0.1 29 | else: 30 | lr = args.base_lr * 0.01 31 | 32 | for param_group in optimizer.param_groups: 33 | param_group["lr"] = lr 34 | 35 | 36 | def train(net, optimizer, epoch, data_loader, args): 37 | 38 | learning_rate = optimizer.param_groups[0]["lr"] 39 | 40 | batch_time = AverageMeter('Time', ':6.3f') 41 | data_time = AverageMeter('Data', ':6.3f') 42 | losses = AverageMeter('Loss', ':.4f') 43 | top1 = AverageMeter('Accuracy', ':4.2f') 44 | progress = ProgressMeter( 45 | len(data_loader), 46 | [batch_time, data_time, losses, top1], 47 | prefix="Epoch (Train LR {:6.4f}): [{}] ".format(learning_rate, epoch)) 48 | 49 | net.train() 50 | 51 | tic = time.time() 52 | for batch_idx, (data, target) in enumerate(data_loader): 53 | 54 | data, target = data.to(args.device, non_blocking=True), target.to(args.device, non_blocking=True) 55 | 56 | data_time.update(time.time() - tic) 57 | 58 | optimizer.zero_grad() 59 | output = net(data) 60 | loss = F.cross_entropy(output, target) 61 | loss.backward() 62 | optimizer.step() 63 | 64 | acc = accuracy(output, target) 65 | losses.update(loss.item(), data.size(0)) 66 | top1.update(acc[0].item(), data.size(0)) 67 | 68 | batch_time.update(time.time() - tic) 69 | tic = time.time() 70 | 71 | if (batch_idx+1) % args.disp_iter == 0 or (batch_idx+1) == len(data_loader): 72 | epoch_msg = progress.get_message(batch_idx+1) 73 | print(epoch_msg) 74 | 75 | args.log_file.write(epoch_msg + "\n") 76 | 77 | def validate(net, epoch, data_loader, args): 78 | 79 | batch_time = AverageMeter('Time', ':6.3f') 80 | data_time = AverageMeter('Data', ':6.3f') 81 | losses = AverageMeter('Loss', ':.4f') 82 | top1 = AverageMeter('Accuracy', ':4.2f') 83 | progress = ProgressMeter( 84 | len(data_loader), 85 | [batch_time, data_time, losses, top1], 86 | prefix="Epoch (Valid LR {:6.4f}): [{}] ".format(0, epoch)) 87 | 88 | net.eval() 89 | 90 | with torch.no_grad(): 91 | tic = time.time() 92 | for batch_idx, (data, target) in enumerate(data_loader): 93 | 94 | data, target = data.to(args.device, non_blocking=True), target.to(args.device, non_blocking=True) 95 | 96 | data_time.update(time.time() - tic) 97 | 98 | output = net(data) 99 | loss = F.cross_entropy(output, target) 100 | 101 | acc = accuracy(output, target) 102 | losses.update(loss.item(), data.size(0)) 103 | top1.update(acc[0].item(), data.size(0)) 104 | 105 | batch_time.update(time.time() - tic) 106 | tic = time.time() 107 | 108 | if (batch_idx+1) % args.disp_iter == 0 or (batch_idx+1) == len(data_loader): 109 | epoch_msg = progress.get_message(batch_idx+1) 110 | print(epoch_msg) 111 | 112 | args.log_file.write(epoch_msg + "\n") 113 | 114 | print('-------- Mean Accuracy {top1.avg:.3f} --------'.format(top1=top1)) 115 | 116 | return top1.avg 117 | 118 | def main(args): 119 | 120 | if len(args.gpu_ids) > 0: 121 | assert(torch.cuda.is_available()) 122 | cudnn.benchmark = True 123 | kwargs = {"num_workers": args.workers, "pin_memory": True} 124 | args.device 
= torch.device("cuda:{}".format(args.gpu_ids[0])) 125 | else: 126 | kwargs = {} 127 | args.device = torch.device("cpu") 128 | 129 | normlizer = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 130 | 131 | print("Building dataset: " + args.dataset) 132 | 133 | if args.dataset == "cifar10": 134 | args.num_class = 10 135 | train_loader = torch.utils.data.DataLoader( 136 | datasets.CIFAR10(args.dataset_dir, train=True, download=True, 137 | transform=transforms.Compose([ 138 | transforms.Pad(4), 139 | transforms.RandomCrop(32), 140 | transforms.RandomHorizontalFlip(), 141 | transforms.ToTensor(), 142 | normlizer])), 143 | batch_size=args.batch_size, shuffle=True, **kwargs) 144 | 145 | test_loader = torch.utils.data.DataLoader( 146 | datasets.CIFAR10(args.dataset_dir, train=False, transform=transforms.Compose([ 147 | transforms.ToTensor(), 148 | normlizer])), 149 | batch_size=100, shuffle=False, **kwargs) 150 | 151 | elif args.dataset == "cifar100": 152 | args.num_class = 100 153 | train_loader = torch.utils.data.DataLoader( 154 | datasets.CIFAR100(args.dataset_dir, train=True, download=True, 155 | transform=transforms.Compose([ 156 | transforms.Pad(4), 157 | transforms.RandomCrop(32), 158 | transforms.RandomHorizontalFlip(), 159 | transforms.ToTensor(), 160 | normlizer])), 161 | batch_size=args.batch_size, shuffle=True, **kwargs) 162 | 163 | test_loader = torch.utils.data.DataLoader( 164 | datasets.CIFAR100(args.dataset_dir, train=False, transform=transforms.Compose([ 165 | transforms.ToTensor(), 166 | normlizer])), 167 | batch_size=100, shuffle=False, **kwargs) 168 | 169 | net = create_net(args) 170 | 171 | print(net) 172 | 173 | optimizer = optim.SGD(net.parameters(), lr=args.base_lr, momentum=args.beta1, weight_decay=args.weight_decay) 174 | 175 | if args.resume: 176 | net, optimizer, best_acc, start_epoch = load_checkpoint(args, net, optimizer) 177 | else: 178 | start_epoch = 0 179 | best_acc = 0 180 | 181 | x = torch.randn(1, 3, 32, 32) 182 | flops, params = profile(net, inputs=(x,)) 183 | 184 | print("Number of params: %.6fM" % (params / 1e6)) 185 | print("Number of FLOPs: %.6fG" % (flops / 1e9)) 186 | 187 | args.log_file.write("Network - " + args.arch + "\n") 188 | args.log_file.write("Attention Module - " + args.attention_type + "\n") 189 | args.log_file.write("Params - %.6fM" % (params / 1e6) + "\n") 190 | args.log_file.write("FLOPs - %.6fG" % (flops / 1e9) + "\n") 191 | args.log_file.write("--------------------------------------------------" + "\n") 192 | 193 | if len(args.gpu_ids) > 0: 194 | net.to(args.gpu_ids[0]) 195 | net = torch.nn.DataParallel(net, args.gpu_ids) # multi-GPUs 196 | 197 | for epoch in range(start_epoch, args.num_epoch): 198 | # if args.wrn: 199 | # adjust_learning_rate_wrn(optimizer, epoch, args.warmup) 200 | # else: 201 | adjust_learning_rate(optimizer, epoch, args.warmup) 202 | 203 | train(net, optimizer, epoch, train_loader, args) 204 | epoch_acc = validate(net, epoch, test_loader, args) 205 | 206 | is_best = epoch_acc > best_acc 207 | best_acc = max(epoch_acc, best_acc) 208 | 209 | save_checkpoint({ 210 | "epoch": epoch + 1, 211 | "arch": args.arch, 212 | "state_dict": net.module.cpu().state_dict(), 213 | "best_acc": best_acc, 214 | "optimizer" : optimizer.state_dict(), 215 | }, is_best, epoch, save_path=args.ckpt) 216 | 217 | net.to(args.device) 218 | 219 | args.log_file.write("--------------------------------------------------" + "\n") 220 | 221 | args.log_file.write("best accuracy %4.2f" % best_acc) 222 | 223 | print("Job Done!") 224 
|
225 | if __name__ == "__main__":
226 | 
227 |     parser = argparse.ArgumentParser(description="CIFAR baseline")
228 | 
229 |     # Model settings
230 |     parser.add_argument("--arch", type=str, default="resnet18",
231 |                         help="network architecture (default: resnet18)")
232 |     parser.add_argument("--num_base_filters", type=int, default=16,
233 |                         help="network base filter numbers (default: 16)")
234 |     parser.add_argument("--expansion", type=float, default=1,
235 |                         help="expansion factor for the mid-layer in resnet-like")
236 |     parser.add_argument("--block_type", type=str, default="basic",
237 |                         help="building block for network, e.g., basic or bottlenect")
238 |     parser.add_argument("--attention_type", type=str, default="none",
239 |                         help="attention type in building block (possible choices none | se | cbam | simam)")
240 |     parser.add_argument("--attention_param", type=float, default=4,
241 |                         help="attention parameter (reduction in CBAM and SE, e_lambda in simam)")
242 | 
243 |     # Dataset settings
244 |     parser.add_argument("--dataset", type=str, default="cifar10",
245 |                         help="training dataset (default: cifar10)")
246 |     parser.add_argument("--dataset_dir", type=str, default="data",
247 |                         help="dataset path (default: data)")
248 |     parser.add_argument("--workers", default=16, type=int,
249 |                         help="number of data loading workers")
250 | 
251 |     # Optimization settings
252 |     parser.add_argument("--gpu_ids", default="0",
253 |                         help="gpus to use, e.g. 0-3 or 0,1,2,3")
254 |     parser.add_argument("--batch_size", type=int, default=128,
255 |                         help="batch size for training and validation (default: 128)")
256 |     parser.add_argument("--num_epoch", type=int, default=164,
257 |                         help="number of epochs to train (default: 164)")
258 |     parser.add_argument("--resume", default="", type=str,
259 |                         help="path to checkpoint for continued training (default: none)")
260 |     parser.add_argument("--optim", default="SGD",
261 |                         help="optimizer")
262 |     parser.add_argument("--base_lr", type=float, default=0.1,
263 |                         help="learning rate (default: 0.1)")
264 |     parser.add_argument("--beta1", default=0.9, type=float,
265 |                         help="momentum for sgd, beta1 for adam")
266 |     parser.add_argument("--weight_decay", type=float, default=5e-4,
267 |                         help="SGD weight decay (default: 5e-4)")
268 |     parser.add_argument("--warmup", action="store_true",
269 |                         help="warmup for deeper network")
270 |     parser.add_argument("--wrn", action="store_true",
271 |                         help="wider resnet for training")
272 | 
273 |     # Misc
274 |     parser.add_argument("--seed", type=int, default=1,
275 |                         help="random seed (default: 1)")
276 |     parser.add_argument("--disp_iter", type=int, default=100,
277 |                         help="frequency to display training status (default: 100)")
278 |     parser.add_argument("--ckpt", default="./ckpts/",
279 |                         help="folder to output checkpoints")
280 | 
281 |     args = parser.parse_args()
282 |     args.gpu_ids = parse_gpus(args.gpu_ids)
283 | 
284 |     args.ckpt += args.dataset
285 |     args.ckpt += "-" + args.arch
286 |     args.ckpt += "-" + args.block_type
287 |     if args.attention_type.lower() != "none":
288 |         args.ckpt += "-" + args.attention_type
289 |         args.ckpt += "-param" + str(args.attention_param)
290 | 
291 |     args.ckpt += "-nfilters" + str(args.num_base_filters)
292 |     args.ckpt += "-expansion" + str(args.expansion)
293 |     args.ckpt += "-baselr" + str(args.base_lr)
294 |     args.ckpt += "-rseed" + str(args.seed)
295 |     for key, val in vars(args).items():
296 |         print("{:16} {}".format(key, val))
297 | 
298 |     if not os.path.isdir(args.ckpt):
299 | 
os.makedirs(args.ckpt) 300 | 301 | # write to file 302 | args.log_file = open(os.path.join(args.ckpt, "log_file.txt"), mode="w") 303 | 304 | random.seed(args.seed) 305 | torch.manual_seed(args.seed) 306 | 307 | main(args) 308 | 309 | args.log_file.close() 310 | -------------------------------------------------------------------------------- /networks/imagenet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import functools 4 | # from .utils import load_state_dict_from_url 5 | # note the refinement module 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=dilation, groups=groups, bias=False, dilation=dilation) 27 | 28 | 29 | def conv1x1(in_planes, out_planes, stride=1): 30 | """1x1 convolution""" 31 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 32 | 33 | 34 | class BasicBlock(nn.Module): 35 | expansion = 1 36 | 37 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 38 | base_width=64, dilation=1, norm_layer=None, attention_module=None): 39 | super(BasicBlock, self).__init__() 40 | if norm_layer is None: 41 | norm_layer = nn.BatchNorm2d 42 | if groups != 1 or base_width != 64: 43 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 44 | if dilation > 1: 45 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 46 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 47 | self.conv1 = conv3x3(inplanes, planes, stride) 48 | self.bn1 = norm_layer(planes) 49 | self.relu = nn.ReLU(inplace=True) 50 | self.conv2 = conv3x3(planes, planes) 51 | self.bn2 = norm_layer(planes) 52 | self.downsample = downsample 53 | self.stride = stride 54 | 55 | if attention_module is not None: 56 | if type(attention_module) == functools.partial: 57 | module_name = attention_module.func.get_module_name() 58 | else: 59 | module_name = attention_module.get_module_name() 60 | 61 | 62 | if module_name == "simam": 63 | self.conv2 = nn.Sequential( 64 | self.conv2, 65 | attention_module(planes) 66 | ) 67 | else: 68 | self.bn2 = nn.Sequential( 69 | self.bn2, 70 | attention_module(planes) 71 | ) 72 | 73 | def forward(self, x): 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += 
identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 94 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 95 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 96 | # This variant is also known as ResNet V1.5 and improves accuracy according to 97 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 98 | 99 | expansion = 4 100 | 101 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 102 | base_width=64, dilation=1, norm_layer=None, attention_module=None): 103 | super(Bottleneck, self).__init__() 104 | if norm_layer is None: 105 | norm_layer = nn.BatchNorm2d 106 | width = int(planes * (base_width / 64.)) * groups 107 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 108 | self.conv1 = conv1x1(inplanes, width) 109 | self.bn1 = norm_layer(width) 110 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 111 | self.bn2 = norm_layer(width) 112 | self.conv3 = conv1x1(width, planes * self.expansion) 113 | self.bn3 = norm_layer(planes * self.expansion) 114 | self.relu = nn.ReLU(inplace=True) 115 | self.downsample = downsample 116 | self.stride = stride 117 | 118 | if attention_module is not None: 119 | if type(attention_module) == functools.partial: 120 | module_name = attention_module.func.get_module_name() 121 | else: 122 | module_name = attention_module.get_module_name() 123 | 124 | if module_name == "simam": 125 | self.conv2 = nn.Sequential( 126 | self.conv2, 127 | attention_module(width) 128 | ) 129 | else: 130 | self.bn3 = nn.Sequential( 131 | self.bn3, 132 | attention_module(planes * self.expansion) 133 | ) 134 | 135 | def forward(self, x): 136 | identity = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | out = self.relu(out) 145 | 146 | out = self.conv3(out) 147 | out = self.bn3(out) 148 | 149 | if self.downsample is not None: 150 | identity = self.downsample(x) 151 | 152 | out += identity 153 | out = self.relu(out) 154 | 155 | return out 156 | 157 | 158 | class ResNet(nn.Module): 159 | 160 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 161 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 162 | norm_layer=None, attention_module=None, 163 | deep_stem=False, stem_width=32, avg_down=False): 164 | super(ResNet, self).__init__() 165 | if norm_layer is None: 166 | norm_layer = nn.BatchNorm2d 167 | 168 | self._norm_layer = norm_layer 169 | 170 | self.inplanes = stem_width*2 if deep_stem else 64 171 | 172 | self.dilation = 1 173 | 174 | self.groups = groups 175 | self.base_width = width_per_group 176 | 177 | if replace_stride_with_dilation is None: 178 | # each element in the tuple indicates if we should replace 179 | # the 2x2 stride with a dilated convolution instead 180 | replace_stride_with_dilation = [False, False, False] 181 | if len(replace_stride_with_dilation) != 3: 182 | raise ValueError("replace_stride_with_dilation should be None " 183 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 184 | 185 | if deep_stem: 186 | self.conv1 = nn.Sequential( 187 | conv3x3(3, stem_width, stride=2), 188 | norm_layer(stem_width), 189 | nn.ReLU(), 190 | conv3x3(stem_width, stem_width, stride=1), 191 | 
norm_layer(stem_width), 192 | nn.ReLU(), 193 | conv3x3(stem_width, self.inplanes, stride=1), 194 | ) 195 | else: 196 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) 197 | 198 | self.bn1 = norm_layer(self.inplanes if not deep_stem else stem_width*2) 199 | self.relu = nn.ReLU(inplace=True) 200 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 201 | 202 | self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down, 203 | attention_module=attention_module) 204 | 205 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down, 206 | dilate=replace_stride_with_dilation[0], 207 | attention_module=attention_module) 208 | 209 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, avg_down=avg_down, 210 | dilate=replace_stride_with_dilation[1], 211 | attention_module=attention_module) 212 | 213 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, avg_down=avg_down, 214 | dilate=replace_stride_with_dilation[2], 215 | attention_module=attention_module) 216 | 217 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 218 | self.fc = nn.Linear(512 * block.expansion, num_classes) 219 | 220 | # print(self.modules) 221 | 222 | for m in self.modules(): 223 | # print(m) 224 | if isinstance(m, nn.Conv2d): 225 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 226 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 227 | nn.init.constant_(m.weight, 1) 228 | nn.init.constant_(m.bias, 0) 229 | 230 | # Zero-initialize the last BN in each residual branch, 231 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 232 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 233 | if zero_init_residual: 234 | for m in self.modules(): 235 | if isinstance(m, Bottleneck): 236 | nn.init.constant_(m.bn3.weight, 0) 237 | elif isinstance(m, BasicBlock): 238 | nn.init.constant_(m.bn2.weight, 0) 239 | 240 | def _make_layer(self, block, planes, blocks, stride=1, avg_down=False, dilate=False, attention_module=None): 241 | norm_layer = self._norm_layer 242 | downsample = None 243 | previous_dilation = self.dilation 244 | if dilate: 245 | self.dilation *= stride 246 | stride = 1 247 | if stride != 1 or self.inplanes != planes * block.expansion: 248 | if avg_down and stride != 1: 249 | downsample = nn.Sequential( 250 | nn.AvgPool2d(kernel_size=stride, stride=stride, count_include_pad=False, ceil_mode=True), 251 | conv1x1(self.inplanes, planes * block.expansion, 1), 252 | norm_layer(planes * block.expansion) 253 | ) 254 | else: 255 | downsample = nn.Sequential( 256 | conv1x1(self.inplanes, planes * block.expansion, stride=stride), 257 | norm_layer(planes * block.expansion) 258 | ) 259 | 260 | 261 | layers = [] 262 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 263 | self.base_width, previous_dilation, norm_layer, attention_module)) 264 | self.inplanes = planes * block.expansion 265 | for _ in range(1, blocks): 266 | layers.append(block(self.inplanes, planes, groups=self.groups, 267 | base_width=self.base_width, dilation=self.dilation, 268 | norm_layer=norm_layer, attention_module=attention_module)) 269 | 270 | return nn.Sequential(*layers) 271 | 272 | def _forward_impl(self, x): 273 | # See note [TorchScript super()] 274 | x = self.conv1(x) 275 | x = self.bn1(x) 276 | x = self.relu(x) 277 | x = self.maxpool(x) 278 | x = self.layer1(x) 279 | x = self.layer2(x) 280 | x = self.layer3(x) 281 | x = 
self.layer4(x)
282 | 
283 |         x = self.avgpool(x)
284 |         x = torch.flatten(x, 1)
285 |         x = self.fc(x)
286 | 
287 |         return x
288 | 
289 |     def forward(self, x):
290 |         return self._forward_impl(x)
291 | 
292 | 
293 | def _resnet(arch, block, layers, **kwargs):
294 |     model = ResNet(block, layers, **kwargs)
295 |     # if pretrained:
296 |     #     state_dict = load_state_dict_from_url(model_urls[arch],
297 |     #                                           progress=progress)
298 |     #     model.load_state_dict(state_dict)
299 |     return model
300 | 
301 | 
302 | def resnet18(**kwargs):
303 |     r"""ResNet-18 model from
304 |     `"Deep Residual Learning for Image Recognition" `_
305 | 
306 |     Args:
307 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
308 |         progress (bool): If True, displays a progress bar of the download to stderr
309 |     """
310 | 
311 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], **kwargs)
312 | 
313 | 
314 | 
315 | def resnet34(**kwargs):
316 |     r"""ResNet-34 model from
317 |     `"Deep Residual Learning for Image Recognition" `_
318 | 
319 |     Args:
320 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
321 |         progress (bool): If True, displays a progress bar of the download to stderr
322 |     """
323 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], **kwargs)
324 | 
325 | 
326 | 
327 | def resnet50(**kwargs):
328 |     r"""ResNet-50 model from
329 |     `"Deep Residual Learning for Image Recognition" `_
330 | 
331 |     Args:
332 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
333 |         progress (bool): If True, displays a progress bar of the download to stderr
334 |     """
335 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], **kwargs)
336 | 
337 | 
338 | 
339 | def resnet101(**kwargs):
340 |     r"""ResNet-101 model from
341 |     `"Deep Residual Learning for Image Recognition" `_
342 | 
343 |     Args:
344 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
345 |         progress (bool): If True, displays a progress bar of the download to stderr
346 |     """
347 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], **kwargs)
348 | 
349 | 
350 | 
351 | def resnet152(**kwargs):
352 |     r"""ResNet-152 model from
353 |     `"Deep Residual Learning for Image Recognition" `_
354 | 
355 |     Args:
356 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
357 |         progress (bool): If True, displays a progress bar of the download to stderr
358 |     """
359 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], **kwargs)
360 | 
361 | 
362 | 
363 | def resnext50_32x4d(**kwargs):
364 |     r"""ResNeXt-50 32x4d model from
365 |     `"Aggregated Residual Transformation for Deep Neural Networks" `_.
366 |     Args:
367 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
368 |         progress (bool): If True, displays a progress bar of the download to stderr
369 |     """
370 |     kwargs['groups'] = 32
371 |     kwargs['width_per_group'] = 4
372 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], **kwargs)
373 | 
374 | 
375 | def resnext101_64x4d(**kwargs):
376 |     r"""ResNeXt-101 64x4d model from
377 |     `"Aggregated Residual Transformation for Deep Neural Networks" `_.
378 | Args: 379 | pretrained (bool): If True, returns a model pre-trained on ImageNet 380 | progress (bool): If True, displays a progress bar of the download to stderr 381 | """ 382 | kwargs['groups'] = 64 383 | kwargs['width_per_group'] = 4 384 | return _resnet('resnext101_64x4d', Bottleneck, [3, 4, 23, 3], **kwargs) 385 | 386 | 387 | 388 | def resnet50d(**kwargs): 389 | r"""ResNet-50 model from 390 | `"Deep Residual Learning for Image Recognition" `_ 391 | 392 | Args: 393 | pretrained (bool): If True, returns a model pre-trained on ImageNet 394 | progress (bool): If True, displays a progress bar of the download to stderr 395 | """ 396 | 397 | return _resnet('resnet50d', Bottleneck, [3, 4, 6, 3], 398 | deep_stem=True, stem_width=32, avg_down=True, 399 | **kwargs) -------------------------------------------------------------------------------- /main_imagenet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import shutil 5 | import time 6 | import warnings 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.distributed as dist 13 | import torch.optim 14 | import torch.multiprocessing as mp 15 | import torch.utils.data 16 | import torch.utils.data.distributed 17 | import torchvision.transforms as transforms 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | import numpy as np 21 | 22 | from util import AverageMeter, ProgressMeter, accuracy, parse_gpus 23 | from checkpoint import save_checkpoint, load_checkpoint 24 | from thop import profile 25 | from networks.imagenet import create_net 26 | 27 | 28 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 29 | parser.add_argument('data', metavar='DIR', 30 | help='path to dataset') 31 | parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', 32 | help='model architecture (default: resnet18)') 33 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 34 | help='number of data loading workers (default: 4)') 35 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 36 | help='number of total epochs to run') 37 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 38 | help='manual epoch number (useful on restarts)') 39 | parser.add_argument('-b', '--batch-size', default=256, type=int, 40 | metavar='N', 41 | help='mini-batch size (default: 256), this is the total ' 42 | 'batch size of all GPUs on the current node when ' 43 | 'using Data Parallel or Distributed Data Parallel') 44 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 45 | metavar='LR', help='initial learning rate', dest='lr') 46 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 47 | help='momentum') 48 | parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, 49 | metavar='W', help='weight decay (default: 1e-4)', 50 | dest='weight_decay') 51 | parser.add_argument('-p', '--print-freq', default=10, type=int, 52 | metavar='N', help='print frequency (default: 10)') 53 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 54 | help='path to latest checkpoint (default: none)') 55 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 56 | help='evaluate model on validation set') 57 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 58 | help='use pre-trained 
model') 59 | parser.add_argument('--world-size', default=-1, type=int, 60 | help='number of nodes for distributed training') 61 | parser.add_argument('--rank', default=-1, type=int, 62 | help='node rank for distributed training') 63 | parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, 64 | help='url used to set up distributed training') 65 | parser.add_argument('--dist-backend', default='nccl', type=str, 66 | help='distributed backend') 67 | parser.add_argument('--seed', default=None, type=int, 68 | help='seed for initializing training. ') 69 | parser.add_argument('--gpu', default="0", 70 | help='GPU id to use.') 71 | parser.add_argument('--multiprocessing-distributed', action='store_true', 72 | help='Use multi-processing distributed training to launch ' 73 | 'N processes per node, which has N GPUs. This is the ' 74 | 'fastest way to use PyTorch for either single node or ' 75 | 'multi node data parallel training') 76 | 77 | parser.add_argument("--ckpt", default="./ckpts/", 78 | help="folder to output checkpoints") 79 | parser.add_argument("--attention_type", type=str, default="none", 80 | help="attention type (possible choices none | se | cbam | simam)") 81 | parser.add_argument("--attention_param", type=float, default=4, 82 | help="attention parameter (reduction factor in se and cbam, e_lambda in simam)") 83 | parser.add_argument("--log_freq", type=int, default=500, 84 | help="log frequency to file") 85 | parser.add_argument("--cos_lr", action='store_true', 86 | help='use cosine learning rate') 87 | parser.add_argument("--save_weights", default=None, type=str, metavar='PATH', 88 | help='save weights by CPU for mmdetection') 89 | 90 | 91 | best_acc1 = 0 92 | 93 | 94 | def main(): 95 | args = parser.parse_args() 96 | 97 | args.ckpt += "imagenet" 98 | args.ckpt += "-" + args.arch 99 | if args.attention_type.lower() != "none": 100 | args.ckpt += "-" + args.attention_type 101 | if args.attention_type.lower() != "none": 102 | args.ckpt += "-param" + str(args.attention_param) 103 | 104 | 105 | args.gpu = parse_gpus(args.gpu) 106 | if args.gpu is not None: 107 | args.device = torch.device("cuda:{}".format(args.gpu[0])) 108 | else: 109 | args.device = torch.device("cpu") 110 | 111 | 112 | if args.seed is not None: 113 | random.seed(args.seed) 114 | torch.manual_seed(args.seed) 115 | cudnn.deterministic = True 116 | warnings.warn('You have chosen to seed training. ' 117 | 'This will turn on the CUDNN deterministic setting, ' 118 | 'which can slow down your training considerably! ' 119 | 'You may see unexpected behavior when restarting ' 120 | 'from checkpoints.') 121 | args.ckpt += '-seed' + str(args.seed) 122 | 123 | if not os.path.isdir(args.ckpt): 124 | os.makedirs(args.ckpt) 125 | 126 | if args.gpu is not None: 127 | warnings.warn('You have chosen a specific GPU. 
This will completely '
128 |                       'disable data parallelism.')
129 | 
130 |     if args.dist_url == "env://" and args.world_size == -1:
131 |         args.world_size = int(os.environ["WORLD_SIZE"])
132 | 
133 |     args.distributed = args.world_size > 1 or args.multiprocessing_distributed
134 | 
135 |     ngpus_per_node = torch.cuda.device_count()
136 |     if args.multiprocessing_distributed:
137 |         # Since we have ngpus_per_node processes per node, the total world_size
138 |         # needs to be adjusted accordingly
139 |         args.world_size = ngpus_per_node * args.world_size
140 |         # Use torch.multiprocessing.spawn to launch distributed processes: the
141 |         # main_worker process function
142 |         mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
143 |     else:
144 |         # Simply call main_worker function
145 |         main_worker(args.gpu, ngpus_per_node, args)
146 | 
147 | 
148 | def main_worker(gpu, ngpus_per_node, args):
149 |     global best_acc1
150 | 
151 |     if args.gpu is not None:
152 |         print("Use GPU: {} for training".format(args.gpu))
153 | 
154 |     if args.distributed:
155 |         if args.dist_url == "env://" and args.rank == -1:
156 |             args.rank = int(os.environ["RANK"])
157 |         if args.multiprocessing_distributed:
158 |             # For multiprocessing distributed training, rank needs to be the
159 |             # global rank among all the processes
160 |             args.rank = args.rank * ngpus_per_node + gpu
161 |         dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
162 |                                 world_size=args.world_size, rank=args.rank)
163 |     # create model
164 |     model = create_net(args)
165 | 
166 | 
167 |     x = torch.randn(1, 3, 224, 224)
168 |     flops, params = profile(model, inputs=(x,))
169 | 
170 |     print("model [%s] - params: %.6fM" % (args.arch, params / 1e6))
171 |     print("model [%s] - FLOPs: %.6fG" % (args.arch, flops / 1e9))
172 | 
173 |     log_file = os.path.join(args.ckpt, "log.txt")
174 | 
175 |     if os.path.exists(log_file):
176 |         args.log_file = open(log_file, mode="a")
177 |     else:
178 |         args.log_file = open(log_file, mode="w")
179 |         args.log_file.write("Network - " + args.arch + "\n")
180 |         args.log_file.write("Attention Module - " + args.attention_type + "\n")
181 |         args.log_file.write("Params - %.6fM" % (params / 1e6) + "\n")
182 |         args.log_file.write("FLOPs - %.6fG" % (flops / 1e9) + "\n")
183 |         args.log_file.write("--------------------------------------------------" + "\n")
184 | 
185 |     args.log_file.close()
186 | 
187 | 
188 |     if not torch.cuda.is_available():
189 |         print('using CPU, this will be slow')
190 |     elif args.distributed:
191 |         # For multiprocessing distributed, DistributedDataParallel constructor
192 |         # should always set the single device scope, otherwise,
193 |         # DistributedDataParallel will use all available devices.
194 | if args.gpu is not None: 195 | torch.cuda.set_device(args.device) 196 | model.cuda(args.gpu) 197 | # When using a single GPU per process and per 198 | # DistributedDataParallel, we need to divide the batch size 199 | # ourselves based on the total number of GPUs we have 200 | args.batch_size = int(args.batch_size / ngpus_per_node) 201 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 202 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 203 | else: 204 | model.cuda() 205 | # DistributedDataParallel will divide and allocate batch_size to all 206 | # available GPUs if device_ids are not set 207 | model = torch.nn.parallel.DistributedDataParallel(model) 208 | elif args.gpu is not None: 209 | torch.cuda.set_device(args.device) 210 | model = model.to(args.gpu[0]) 211 | model = torch.nn.DataParallel(model, args.gpu) 212 | 213 | print(model) 214 | 215 | # define loss function (criterion) and optimizer 216 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 217 | 218 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 219 | momentum=args.momentum, 220 | weight_decay=args.weight_decay) 221 | if args.resume: 222 | model, optimizer, best_acc1, start_epoch = load_checkpoint(args, model, optimizer) 223 | args.start_epoch = start_epoch 224 | 225 | cudnn.benchmark = True 226 | 227 | # Data loading code 228 | traindir = os.path.join(args.data, 'train') 229 | valdir = os.path.join(args.data, 'val') 230 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 231 | std=[0.229, 0.224, 0.225]) 232 | 233 | train_dataset = datasets.ImageFolder( 234 | traindir, 235 | transforms.Compose([ 236 | transforms.RandomResizedCrop(224), 237 | transforms.RandomHorizontalFlip(), 238 | transforms.ToTensor(), 239 | normalize, 240 | ])) 241 | 242 | if args.distributed: 243 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 244 | else: 245 | train_sampler = None 246 | 247 | train_loader = torch.utils.data.DataLoader( 248 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 249 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 250 | 251 | val_loader = torch.utils.data.DataLoader( 252 | datasets.ImageFolder(valdir, transforms.Compose([ 253 | transforms.Resize(256), 254 | transforms.CenterCrop(224), 255 | transforms.ToTensor(), 256 | normalize, 257 | ])), 258 | batch_size=args.batch_size, shuffle=False, 259 | num_workers=args.workers, pin_memory=True) 260 | 261 | if args.save_weights is not None: # "deparallelize" saved weights 262 | print("=> saving 'deparallelized' weights [%s]" % args.save_weights) 263 | model = model.module 264 | model = model.cpu() 265 | torch.save({'state_dict': model.state_dict()}, args.save_weights, _use_new_zipfile_serialization=False) 266 | return 267 | 268 | 269 | if args.evaluate: 270 | args.log_file = open(log_file, mode="a") 271 | validate(val_loader, model, criterion, args) 272 | args.log_file.close() 273 | return 274 | 275 | if args.cos_lr: 276 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 277 | for epoch in range(args.start_epoch): 278 | scheduler.step() 279 | 280 | 281 | for epoch in range(args.start_epoch, args.epochs): 282 | 283 | args.log_file = open(log_file, mode="a") 284 | 285 | if args.distributed: 286 | train_sampler.set_epoch(epoch) 287 | 288 | if(not args.cos_lr): 289 | adjust_learning_rate(optimizer, epoch, args) 290 | else: 291 | scheduler.step() 292 | print('[%03d] %.5f'%(epoch, scheduler.get_lr()[0])) 
293 | 
294 | 
295 |         # train for one epoch
296 |         train(train_loader, model, criterion, optimizer, epoch, args)
297 | 
298 |         # evaluate on validation set
299 |         acc1 = validate(val_loader, model, criterion, args)
300 | 
301 |         # remember best acc@1 and save checkpoint
302 |         is_best = acc1 > best_acc1
303 |         best_acc1 = max(acc1, best_acc1)
304 | 
305 |         args.log_file.close()
306 | 
307 |         if not args.multiprocessing_distributed or (args.multiprocessing_distributed
308 |                                                      and args.rank % ngpus_per_node == 0):
309 | 
310 |             save_checkpoint({
311 |                 "epoch": epoch + 1,
312 |                 "arch": args.arch,
313 |                 "state_dict": model.state_dict(),
314 |                 "best_acc": best_acc1,
315 |                 "optimizer": optimizer.state_dict(),
316 |             }, is_best, epoch, save_path=args.ckpt)
317 | 
318 | 
319 | def train(train_loader, model, criterion, optimizer, epoch, args):
320 |     batch_time = AverageMeter('Time', ':6.3f')
321 |     data_time = AverageMeter('Data', ':6.3f')
322 |     losses = AverageMeter('Loss', ':.4e')
323 |     top1 = AverageMeter('Acc@1', ':6.2f')
324 |     top5 = AverageMeter('Acc@5', ':6.2f')
325 |     progress = ProgressMeter(
326 |         len(train_loader),
327 |         [batch_time, data_time, losses, top1, top5],
328 |         prefix="Epoch: [{}]".format(epoch))
329 | 
330 |     param_groups = optimizer.param_groups[0]
331 |     curr_lr = param_groups["lr"]
332 | 
333 |     # switch to train mode
334 |     model.train()
335 | 
336 |     end = time.time()
337 |     for i, (images, target) in enumerate(train_loader):
338 |         # measure data loading time
339 |         data_time.update(time.time() - end)
340 | 
341 |         if args.gpu is not None:
342 |             images = images.to(args.device, non_blocking=True)
343 |         if torch.cuda.is_available():
344 |             target = target.to(args.device, non_blocking=True)
345 | 
346 |         # compute output
347 |         output = model(images)
348 |         loss = criterion(output, target)
349 | 
350 |         # measure accuracy and record loss
351 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
352 |         losses.update(loss.item(), images.size(0))
353 |         top1.update(acc1[0], images.size(0))
354 |         top5.update(acc5[0], images.size(0))
355 | 
356 |         # compute gradient and do SGD step
357 |         optimizer.zero_grad()
358 |         loss.backward()
359 |         optimizer.step()
360 | 
361 |         # measure elapsed time
362 |         batch_time.update(time.time() - end)
363 |         end = time.time()
364 | 
365 |         # rebuild the message for either frequency so the log never reuses a stale line
366 |         if i % args.print_freq == 0 or i % args.log_freq == 0:
367 |             epoch_msg = progress.get_message(i) + "\tLr {:.4f}".format(curr_lr)
368 |             if i % args.print_freq == 0:
369 |                 print(epoch_msg)
370 |             if i % args.log_freq == 0:
371 |                 args.log_file.write(epoch_msg + "\n")
372 | 
373 | 
374 | def validate(val_loader, model, criterion, args):
375 |     batch_time = AverageMeter('Time', ':6.3f')
376 |     losses = AverageMeter('Loss', ':.4e')
377 |     top1 = AverageMeter('Acc@1', ':6.2f')
378 |     top5 = AverageMeter('Acc@5', ':6.2f')
379 | 
380 |     progress = ProgressMeter(
381 |         len(val_loader),
382 |         [batch_time, losses, top1, top5],
383 |         prefix='Test: ')
384 | 
385 |     # switch to evaluate mode
386 |     model.eval()
387 | 
388 |     with torch.no_grad():
389 |         end = time.time()
390 |         for i, (images, target) in enumerate(val_loader):
391 | 
392 |             if args.gpu is not None:
393 |                 images = images.to(args.device, non_blocking=True)
394 |             if torch.cuda.is_available():
395 |                 target = target.to(args.device, non_blocking=True)
396 | 
397 |             # compute outputs
398 |             output = model(images)
399 |             loss = criterion(output, target)
400 | 
401 |             # measure accuracy and record loss
402 |             acc1, acc5 = accuracy(output, target, topk=(1, 5))
403 |             losses.update(loss.item(), images.size(0))
404 |             top1.update(acc1[0], images.size(0))
405 |             top5.update(acc5[0], images.size(0))
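            # (added illustration) accuracy() returns percentages per k: a 32-image
            # batch with 24 top-1 hits gives acc1 = tensor([75.]), hence acc1[0] above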
406 | 
407 |             # measure elapsed time
408 |             batch_time.update(time.time() - end)
409 |             end = time.time()
410 | 
411 |             if i % args.print_freq == 0:
412 |                 epoch_msg = progress.get_message(i)
413 |                 print(epoch_msg)
414 | 
415 |         # TODO: this should also be done with the ProgressMeter
416 |         # print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
417 |         #       .format(top1=top1, top5=top5))
418 | 
419 |         epoch_msg = '----------- Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} -----------'.format(top1=top1, top5=top5)
420 | 
421 |         print(epoch_msg)
422 | 
423 |         args.log_file.write(epoch_msg + "\n")
424 | 
425 | 
426 |     return top1.avg
427 | 
428 | def adjust_learning_rate(optimizer, epoch, args):
429 |     """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
430 |     lr = args.lr * (0.1 ** (epoch // 30))  # e.g. 0.1 -> 0.01 at epoch 30 -> 0.001 at epoch 60
431 |     for param_group in optimizer.param_groups:
432 |         param_group['lr'] = lr
433 | 
434 | 
435 | if __name__ == '__main__':
436 |     main()
--------------------------------------------------------------------------------
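For reference, a minimal sketch of reloading the "deparallelized" weights written by the args.save_weights branch of main_imagenet.py above. The torchvision stand-in architecture and the checkpoint path are illustrative assumptions, not defaults shipped with this repo:

import torch
import torchvision.models as models

# Stand-in backbone; this repo actually builds its models from networks/imagenet.
model = models.resnet50(num_classes=1000)
ckpt = torch.load("pretrained/resnet50_simam.pth", map_location="cpu")
# The state_dict was taken from model.module, so keys carry no "module." prefix.
model.load_state_dict(ckpt["state_dict"], strict=False)  # strict=False: the stand-in lacks SimAM modules
model.eval()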