├── .gitignore ├── LICENSE ├── README.md ├── cifar.py ├── hubconf.py ├── imagenet.py └── senet ├── __init__.py ├── baseline.py ├── se_inception.py ├── se_module.py └── se_resnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__ 3 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ryuichiro Hataya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SENet.pytorch 2 | 3 | An implementation of SENet, proposed in **Squeeze-and-Excitation Networks** by Jie Hu, Li Shen and Gang Sun, the winners of the ILSVRC 2017 classification competition. 4 | 5 | SE-ResNet (18, 34, 50, 101 and 152 for ImageNet; 20, 32 and 56 for CIFAR) and SE-Inception-v3 are implemented. 6 | 7 | * `python cifar.py` runs SE-ResNet20 on the CIFAR-10 dataset. 8 | 9 | * `python imagenet.py` and `python -m torch.distributed.launch --nproc_per_node=${NUM_GPUS} imagenet.py` run SE-ResNet50 on the ImageNet (2012) dataset. 10 | + You need to prepare the dataset yourself in `~/.torch/data` or set the environment variable `IMAGENET_ROOT=${PATH_TO_YOUR_IMAGENET}`. 11 | + First download the files and then follow the [instructions](https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#download-the-imagenet-dataset). 12 | + The number of workers and some hyperparameters are fixed, so check and change them if you need to. 13 | + This script uses all available GPUs. To select specific GPUs, use the `CUDA_VISIBLE_DEVICES` variable (e.g. `CUDA_VISIBLE_DEVICES=1,2` to use GPUs 1 and 2). 14 | 15 | For SE-Inception-v3, the input size must be 299x299, [as in the original Inception](https://github.com/tensorflow/models/tree/master/inception). 16 | 17 | ## Pre-requirements 18 | 19 | The codebase is tested with the following setup. 20 | 21 | * Python>=3.8 22 | * PyTorch>=1.6.0 23 | * torchvision>=0.7 24 | 25 | ### For training 26 | 27 | To run `cifar.py` or `imagenet.py`, you need 28 | 29 | * `pip install git+https://github.com/moskomule/homura@v2020.07` 30 | 31 | ## hub 32 | 33 | You can use some SE-ResNet models (`se_resnet{20, 56, 50, 101}`) via `torch.hub`. 
34 | 35 | ```python 36 | import torch.hub 37 | hub_model = torch.hub.load( 38 | 'moskomule/senet.pytorch', 39 | 'se_resnet20', 40 | num_classes=10) 41 | ``` 42 | 43 | Also, a pretrained SE-ResNet50 model is available. 44 | 45 | ```python 46 | import torch.hub 47 | hub_model = torch.hub.load( 48 | 'moskomule/senet.pytorch', 49 | 'se_resnet50', 50 | pretrained=True,) 51 | ``` 52 | 53 | ## Results 54 | 55 | ### SE-ResNet20/Cifar10 56 | 57 | ``` 58 | python cifar.py [--baseline] 59 | ``` 60 | 61 | Note that the CIFAR-10 dataset is expected to be under `~/.torch/data`. 62 | 63 | | | ResNet20 | SE-ResNet20 (reduction 4 or 8) | 64 | |:------------- | :------------- | :------------- | 65 | |max. test accuracy| 92% | 93% | 66 | 67 | ### SE-ResNet50/ImageNet 68 | 69 | ``` 70 | python [-m torch.distributed.launch --nproc_per_node=${NUM_GPUS}] imagenet.py 71 | ``` 72 | 73 | The option `[-m ...]` is for distributed training. Note that the ImageNet dataset is expected to be under `~/.torch/data` or specified as `IMAGENET_ROOT=${PATH_TO_IMAGENET}`. 74 | 75 | *The initial learning rate and mini-batch size differ from the original version because of my computational resources.* 76 | 77 | | | ResNet | SE-ResNet | 78 | |:------------- | :------------- | :------------- | 79 | |max. test accuracy(top1)| 76.15% (*) | 77.06% (**) | 80 | 81 | 82 | + (*): [ResNet-50 in torchvision](https://pytorch.org/docs/stable/torchvision/models.html) 83 | 84 | + (**): When running `imagenet.py` in the distributed setting on 8 GPUs. The weights are [available](https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl). 
85 | 86 | ```python 87 | # !wget https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl 88 | 89 | import torch 90 | from senet.se_resnet import se_resnet50 91 | 92 | senet = se_resnet50(num_classes=1000) 93 | senet.load_state_dict(torch.load("seresnet50-60a8950a85b2b.pkl")) 94 | ``` 95 | 96 | ## Contribution 97 | 98 | I cannot maintain this repository actively, but any contributions are welcome. Feel free to send PRs and issues. 99 | 100 | ## References 101 | 102 | [paper](https://arxiv.org/pdf/1709.01507.pdf) 103 | 104 | [authors' Caffe implementation](https://github.com/hujie-frank/SENet) 105 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | from homura import callbacks, lr_scheduler, optim, reporters 4 | from homura.trainers import SupervisedTrainer as Trainer 5 | from homura.vision import DATASET_REGISTRY 6 | from senet.baseline import resnet20 7 | from senet.se_resnet import se_resnet20 8 | 9 | 10 | def main(): 11 | train_loader, test_loader = DATASET_REGISTRY("cifar10")(args.batch_size, num_workers=args.num_workers) 12 | 13 | if args.baseline: 14 | model = resnet20() 15 | else: 16 | model = se_resnet20(num_classes=10, reduction=args.reduction) 17 | 18 | optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4) 19 | scheduler = lr_scheduler.StepLR(80, 0.1) 20 | tqdm_rep = reporters.TQDMReporter(range(args.epochs)) 21 | _callbacks = [tqdm_rep, callbacks.AccuracyCallback()] 22 | with Trainer(model, optimizer, F.cross_entropy, scheduler=scheduler, callbacks=_callbacks) as trainer: 23 | for _ in tqdm_rep: 24 | trainer.train(train_loader) 25 | trainer.test(test_loader) 26 | 27 | 28 | if __name__ == "__main__": 29 | import argparse 30 | 31 | p = argparse.ArgumentParser() 32 | p.add_argument("--epochs", type=int, default=200) 33 | p.add_argument("--batch_size", type=int, default=64) 34 | p.add_argument("--reduction", 
type=int, default=16) 35 | p.add_argument("--num_workers", type=int, default=4) 36 | p.add_argument("--baseline", action="store_true") 37 | args = p.parse_args() 38 | main() 39 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch", "math"] 2 | 3 | 4 | def se_resnet20(**kwargs): 5 | from senet.se_resnet import se_resnet20 as _se_resnet20 6 | 7 | return _se_resnet20(**kwargs) 8 | 9 | 10 | def se_resnet56(**kwargs): 11 | from senet.se_resnet import se_resnet56 as _se_resnet56 12 | 13 | return _se_resnet56(**kwargs) 14 | 15 | 16 | def se_resnet50(**kwargs): 17 | from senet.se_resnet import se_resnet50 as _se_resnet50 18 | 19 | return _se_resnet50(**kwargs) 20 | 21 | 22 | def se_resnet101(**kwargs): 23 | from senet.se_resnet import se_resnet101 as _se_resnet101 24 | 25 | return _se_resnet101(**kwargs) 26 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | from homura import callbacks, init_distributed, lr_scheduler, optim, reporters, is_distributed 5 | from homura.trainers import SupervisedTrainer 6 | from homura.vision import DATASET_REGISTRY 7 | from senet.se_resnet import se_resnet50 8 | 9 | 10 | def main(): 11 | if is_distributed(): 12 | init_distributed() 13 | 14 | model = se_resnet50(num_classes=1000) 15 | 16 | optimizer = optim.SGD(lr=0.6 / 1024 * args.batch_size, momentum=0.9, weight_decay=1e-4) 17 | scheduler = lr_scheduler.MultiStepLR([50, 70]) 18 | train_loader, test_loader = DATASET_REGISTRY("imagenet")(args.batch_size) 19 | 20 | c = [ 21 | callbacks.AccuracyCallback(), 22 | callbacks.AccuracyCallback(k=5), 23 | callbacks.LossCallback(), 24 | callbacks.WeightSave("."), 25 | reporters.TensorboardReporter("."), 26 | 
reporters.TQDMReporter(range(args.epochs)), 27 | ] 28 | 29 | with SupervisedTrainer(model, optimizer, F.cross_entropy, callbacks=c, scheduler=scheduler,) as trainer: 30 | for _ in c[-1]: 31 | trainer.train(train_loader) 32 | trainer.test(test_loader) 33 | 34 | 35 | if __name__ == "__main__": 36 | import argparse 37 | import warnings 38 | 39 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 40 | 41 | p = argparse.ArgumentParser() 42 | p.add_argument("--epochs", type=int, default=90) 43 | p.add_argument("--batch_size", type=int, default=128) 44 | p.add_argument("--local_rank", type=int, default=-1) 45 | args = p.parse_args() 46 | 47 | main() 48 | -------------------------------------------------------------------------------- /senet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/moskomule/senet.pytorch/8cb2669fec6fa344481726f9199aa611f08c3fbd/senet/__init__.py -------------------------------------------------------------------------------- /senet/baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet for CIFAR dataset proposed in He+15, p 7. 
and 3 | https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, inplanes, planes, stride=1): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | if inplanes != planes: 24 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(planes)) 26 | else: 27 | self.downsample = lambda x: x 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = self.downsample(x) 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class PreActBasicBlock(BasicBlock): 46 | def __init__(self, inplanes, planes, stride): 47 | super(PreActBasicBlock, self).__init__(inplanes, planes, stride) 48 | if inplanes != planes: 49 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)) 50 | else: 51 | self.downsample = lambda x: x 52 | self.bn1 = nn.BatchNorm2d(inplanes) 53 | 54 | def forward(self, x): 55 | residual = self.downsample(x) 56 | out = self.bn1(x) 57 | out = self.relu(out) 58 | out = self.conv1(out) 59 | 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | out = self.conv2(out) 63 | 64 | out += residual 65 | 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, n_size, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.inplane = 16 
73 | self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(self.inplane) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.layer1 = self._make_layer(block, 16, blocks=n_size, stride=1) 77 | self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2) 78 | self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2) 79 | self.avgpool = nn.AdaptiveAvgPool2d(1) 80 | self.fc = nn.Linear(64, num_classes) 81 | 82 | self.initialize() 83 | 84 | def initialize(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, block, planes, blocks, stride): 93 | 94 | strides = [stride] + [1] * (blocks - 1) 95 | layers = [] 96 | for stride in strides: 97 | layers.append(block(self.inplane, planes, stride)) 98 | self.inplane = planes 99 | 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | x = self.conv1(x) 104 | x = self.bn1(x) 105 | x = self.relu(x) 106 | 107 | x = self.layer1(x) 108 | x = self.layer2(x) 109 | x = self.layer3(x) 110 | 111 | x = self.avgpool(x) 112 | x = x.view(x.size(0), -1) 113 | x = self.fc(x) 114 | 115 | return x 116 | 117 | 118 | class PreActResNet(ResNet): 119 | def __init__(self, block, n_size, num_classes=10): 120 | super(PreActResNet, self).__init__(block, n_size, num_classes) 121 | 122 | self.bn1 = nn.BatchNorm2d(self.inplane) 123 | self.initialize() 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | 134 | x = self.avgpool(x) 135 | x = x.view(x.size(0), -1) 136 | x = self.fc(x) 137 | 138 | return x 139 | 140 | 141 | def resnet20(**kwargs): 142 | model = ResNet(BasicBlock, 3, **kwargs) 143 | return model 144 | 145 | 146 | def 
resnet32(**kwargs): 147 | model = ResNet(BasicBlock, 5, **kwargs) 148 | return model 149 | 150 | 151 | def resnet56(**kwargs): 152 | model = ResNet(BasicBlock, 9, **kwargs) 153 | return model 154 | 155 | 156 | def resnet110(**kwargs): 157 | model = ResNet(BasicBlock, 18, **kwargs) 158 | return model 159 | 160 | 161 | def preact_resnet20(**kwargs): 162 | model = PreActResNet(PreActBasicBlock, 3, **kwargs) 163 | return model 164 | 165 | 166 | def preact_resnet32(**kwargs): 167 | model = PreActResNet(PreActBasicBlock, 5, **kwargs) 168 | return model 169 | 170 | 171 | def preact_resnet56(**kwargs): 172 | model = PreActResNet(PreActBasicBlock, 9, **kwargs) 173 | return model 174 | 175 | 176 | def preact_resnet110(**kwargs): 177 | model = PreActResNet(PreActBasicBlock, 18, **kwargs) 178 | return model 179 | -------------------------------------------------------------------------------- /senet/se_inception.py: -------------------------------------------------------------------------------- 1 | from senet.se_module import SELayer 2 | from torch import nn 3 | from torchvision.models.inception import Inception3 4 | 5 | 6 | class SEInception3(nn.Module): 7 | def __init__(self, num_classes, aux_logits=True, transform_input=False): 8 | super(SEInception3, self).__init__() 9 | model = Inception3(num_classes=num_classes, aux_logits=aux_logits, 10 | transform_input=transform_input) 11 | model.Mixed_5b.add_module("SELayer", SELayer(192)) 12 | model.Mixed_5c.add_module("SELayer", SELayer(256)) 13 | model.Mixed_5d.add_module("SELayer", SELayer(288)) 14 | model.Mixed_6a.add_module("SELayer", SELayer(288)) 15 | model.Mixed_6b.add_module("SELayer", SELayer(768)) 16 | model.Mixed_6c.add_module("SELayer", SELayer(768)) 17 | model.Mixed_6d.add_module("SELayer", SELayer(768)) 18 | model.Mixed_6e.add_module("SELayer", SELayer(768)) 19 | if aux_logits: 20 | model.AuxLogits.add_module("SELayer", SELayer(768)) 21 | model.Mixed_7a.add_module("SELayer", SELayer(768)) 22 | 
model.Mixed_7b.add_module("SELayer", SELayer(1280)) 23 | model.Mixed_7c.add_module("SELayer", SELayer(2048)) 24 | 25 | self.model = model 26 | 27 | def forward(self, x): 28 | _, _, h, w = x.size() 29 | if (h, w) != (299, 299): 30 | raise ValueError("input size must be (299, 299)") 31 | 32 | return self.model(x) 33 | 34 | 35 | def se_inception_v3(**kwargs): 36 | return SEInception3(**kwargs) 37 | -------------------------------------------------------------------------------- /senet/se_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | return x * y.expand_as(x) 20 | -------------------------------------------------------------------------------- /senet/se_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | from torchvision.models import ResNet 4 | from senet.se_module import SELayer 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | class SEBasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 15 | base_width=64, dilation=1, norm_layer=None, 16 | *, reduction=16): 17 | super(SEBasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | 
self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes, 1) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.se = SELayer(planes, reduction) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | out = self.se(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class SEBottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 50 | base_width=64, dilation=1, norm_layer=None, 51 | *, reduction=16): 52 | super(SEBottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.se = SELayer(planes * 4, reduction) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | out = self.se(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | def se_resnet18(num_classes=1_000): 90 | """Constructs an SE-ResNet-18 model. 
91 | 92 | Args: 93 | num_classes (int): Number of output classes 94 | """ 95 | model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes) 96 | model.avgpool = nn.AdaptiveAvgPool2d(1) 97 | return model 98 | 99 | 100 | def se_resnet34(num_classes=1_000): 101 | """Constructs an SE-ResNet-34 model. 102 | 103 | Args: 104 | num_classes (int): Number of output classes 105 | """ 106 | model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes) 107 | model.avgpool = nn.AdaptiveAvgPool2d(1) 108 | return model 109 | 110 | 111 | def se_resnet50(num_classes=1_000, pretrained=False): 112 | """Constructs an SE-ResNet-50 model. 113 | 114 | Args: 115 | pretrained (bool): If True, returns a model pre-trained on ImageNet 116 | """ 117 | model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes) 118 | model.avgpool = nn.AdaptiveAvgPool2d(1) 119 | if pretrained: 120 | model.load_state_dict(load_state_dict_from_url( 121 | "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl")) 122 | return model 123 | 124 | 125 | def se_resnet101(num_classes=1_000): 126 | """Constructs an SE-ResNet-101 model. 127 | 128 | Args: 129 | num_classes (int): Number of output classes 130 | """ 131 | model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes) 132 | model.avgpool = nn.AdaptiveAvgPool2d(1) 133 | return model 134 | 135 | 136 | def se_resnet152(num_classes=1_000): 137 | """Constructs an SE-ResNet-152 model. 
138 | 139 | Args: 140 | num_classes (int): Number of output classes 141 | """ 142 | model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes) 143 | model.avgpool = nn.AdaptiveAvgPool2d(1) 144 | return model 145 | 146 | 147 | class CifarSEBasicBlock(nn.Module): 148 | def __init__(self, inplanes, planes, stride=1, reduction=16): 149 | super(CifarSEBasicBlock, self).__init__() 150 | self.conv1 = conv3x3(inplanes, planes, stride) 151 | self.bn1 = nn.BatchNorm2d(planes) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = nn.BatchNorm2d(planes) 155 | self.se = SELayer(planes, reduction) 156 | if inplanes != planes: 157 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False), 158 | nn.BatchNorm2d(planes)) 159 | else: 160 | self.downsample = lambda x: x 161 | self.stride = stride 162 | 163 | def forward(self, x): 164 | residual = self.downsample(x) 165 | out = self.conv1(x) 166 | out = self.bn1(out) 167 | out = self.relu(out) 168 | 169 | out = self.conv2(out) 170 | out = self.bn2(out) 171 | out = self.se(out) 172 | 173 | out += residual 174 | out = self.relu(out) 175 | 176 | return out 177 | 178 | 179 | class CifarSEResNet(nn.Module): 180 | def __init__(self, block, n_size, num_classes=10, reduction=16): 181 | super(CifarSEResNet, self).__init__() 182 | self.inplane = 16 183 | self.conv1 = nn.Conv2d( 184 | 3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) 185 | self.bn1 = nn.BatchNorm2d(self.inplane) 186 | self.relu = nn.ReLU(inplace=True) 187 | self.layer1 = self._make_layer( 188 | block, 16, blocks=n_size, stride=1, reduction=reduction) 189 | self.layer2 = self._make_layer( 190 | block, 32, blocks=n_size, stride=2, reduction=reduction) 191 | self.layer3 = self._make_layer( 192 | block, 64, blocks=n_size, stride=2, reduction=reduction) 193 | self.avgpool = nn.AdaptiveAvgPool2d(1) 194 | self.fc = nn.Linear(64, num_classes) 195 | 
self.initialize() 196 | 197 | def initialize(self): 198 | for m in self.modules(): 199 | if isinstance(m, nn.Conv2d): 200 | nn.init.kaiming_normal_(m.weight) 201 | elif isinstance(m, nn.BatchNorm2d): 202 | nn.init.constant_(m.weight, 1) 203 | nn.init.constant_(m.bias, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride, reduction): 206 | strides = [stride] + [1] * (blocks - 1) 207 | layers = [] 208 | for stride in strides: 209 | layers.append(block(self.inplane, planes, stride, reduction)) 210 | self.inplane = planes 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def forward(self, x): 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | 223 | x = self.avgpool(x) 224 | x = x.view(x.size(0), -1) 225 | x = self.fc(x) 226 | 227 | return x 228 | 229 | 230 | class CifarSEPreActResNet(CifarSEResNet): 231 | def __init__(self, block, n_size, num_classes=10, reduction=16): 232 | super(CifarSEPreActResNet, self).__init__( 233 | block, n_size, num_classes, reduction) 234 | self.bn1 = nn.BatchNorm2d(self.inplane) 235 | self.initialize() 236 | 237 | def forward(self, x): 238 | x = self.conv1(x) 239 | x = self.layer1(x) 240 | x = self.layer2(x) 241 | x = self.layer3(x) 242 | 243 | x = self.bn1(x) 244 | x = self.relu(x) 245 | 246 | x = self.avgpool(x) 247 | x = x.view(x.size(0), -1) 248 | x = self.fc(x) 249 | return x 250 | 251 | def se_resnet20(**kwargs): 252 | """Constructs an SE-ResNet-20 model for CIFAR. 253 | 254 | """ 255 | model = CifarSEResNet(CifarSEBasicBlock, 3, **kwargs) 256 | return model 257 | 258 | 259 | def se_resnet32(**kwargs): 260 | """Constructs an SE-ResNet-32 model for CIFAR. 261 | 262 | """ 263 | model = CifarSEResNet(CifarSEBasicBlock, 5, **kwargs) 264 | return model 265 | 266 | 267 | def se_resnet56(**kwargs): 268 | """Constructs an SE-ResNet-56 model for CIFAR. 
269 | 270 | """ 271 | model = CifarSEResNet(CifarSEBasicBlock, 9, **kwargs) 272 | return model 273 | 274 | 275 | def se_preactresnet20(**kwargs): 276 | """Constructs a pre-activation SE-ResNet-20 model for CIFAR. 277 | 278 | """ 279 | model = CifarSEPreActResNet(CifarSEBasicBlock, 3, **kwargs) 280 | return model 281 | 282 | 283 | def se_preactresnet32(**kwargs): 284 | """Constructs a pre-activation SE-ResNet-32 model for CIFAR. 285 | 286 | """ 287 | model = CifarSEPreActResNet(CifarSEBasicBlock, 5, **kwargs) 288 | return model 289 | 290 | 291 | def se_preactresnet56(**kwargs): 292 | """Constructs a pre-activation SE-ResNet-56 model for CIFAR. 293 | 294 | """ 295 | model = CifarSEPreActResNet(CifarSEBasicBlock, 9, **kwargs) 296 | return model 297 | --------------------------------------------------------------------------------
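
As a reading aid for `senet/se_module.py`, the following is a minimal, dependency-free sketch of what `SELayer.forward` computes for a single sample: squeeze (global average pooling per channel), excitation (bias-free linear, ReLU, bias-free linear, sigmoid), then channel-wise rescaling. The function name `se_forward`, the nested-list layout and the toy weights are assumptions made for illustration only; they are not part of the repository, and the real layer operates on batched `torch` tensors.

```python
import math


def se_forward(x, w1, w2):
    """Apply squeeze-and-excitation to one sample (illustrative helper).

    x:  feature map as nested lists with shape [C][H][W]
    w1: first bias-free linear layer, shape [C // r][C]
    w2: second bias-free linear layer, shape [C][C // r]
    """
    # Squeeze: global average pooling gives one scalar per channel.
    squeezed = [
        sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
        for channel in x
    ]
    # Excitation, step 1: linear projection to C // r units, then ReLU.
    hidden = [
        max(0.0, sum(w * s for w, s in zip(weights, squeezed)))
        for weights in w1
    ]
    # Excitation, step 2: linear projection back to C units, then sigmoid.
    gates = [
        1.0 / (1.0 + math.exp(-sum(w * h for w, h in zip(weights, hidden))))
        for weights in w2
    ]
    # Scale: multiply every spatial position of a channel by its gate.
    scaled = [
        [[value * gate for value in row] for row in channel]
        for channel, gate in zip(x, gates)
    ]
    return scaled, gates


# Toy input: 2 channels, 2x2 spatial size, reduction r = 2 (hypothetical values).
x = [
    [[1.0, 1.0], [1.0, 1.0]],  # channel 0
    [[2.0, 4.0], [6.0, 8.0]],  # channel 1
]
w1 = [[0.0, 0.0]]  # shape [1][2]; zero weights make the hidden unit 0.0
w2 = [[1.0], [1.0]]  # shape [2][1]

scaled, gates = se_forward(x, w1, w2)
print(gates)  # both gates equal sigmoid(0) = 0.5
```

With zero first-layer weights, every gate is `sigmoid(0) = 0.5`, so each channel is simply halved; with trained weights the gates differ per channel, which is how the layer reweights features.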