├── .gitignore ├── LICENSE ├── README.md ├── cifar.py ├── hubconf.py ├── imagenet.py ├── senet ├── __init__.py ├── baseline.py ├── se_inception.py ├── se_module.py └── se_resnet.py └── sfnet ├── __init__.py └── sfnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__ 3 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Ryuichiro Hataya 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SFNet.pytorch 2 | 3 | An unofficial implementation of SFNet, proposed in **Semantic Flow for Fast and Accurate Scene Parsing** by Xiangtai Li1*, Ansheng You1*, Zhen Zhu2, Houlong Zhao3, Maoke Yang3, Kuiyuan Yang3, Yunhai Tong1 4 | 5 | 6 | ## Pre-requirements 7 | 8 | * Python>=3.6 9 | * PyTorch>=1.0 10 | * torchvision>=0.3 11 | 12 | 13 | ## References 14 | 15 | [paper](https://arxiv.org/pdf/2002.10120v1.pdf) 16 | 17 | [authors' pytorch implementation](https://github.com/donnyyou/torchcv) 18 | -------------------------------------------------------------------------------- /cifar.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from homura import optim, lr_scheduler, callbacks, reporters 3 | from homura.trainers import SupervisedTrainer as Trainer 4 | from homura.vision.data.loaders import cifar10_loaders 5 | 6 | from senet.baseline import resnet20 7 | from senet.se_resnet import se_resnet20 8 | 9 | 10 | def main(): 11 | train_loader, test_loader = cifar10_loaders(args.batch_size) 12 | 13 | if args.baseline: 14 | model = resnet20() 15 | else: 16 | model = se_resnet20(num_classes=10, reduction=args.reduction) 17 | optimizer = optim.SGD(lr=1e-1, momentum=0.9, weight_decay=1e-4) 18 | scheduler = lr_scheduler.StepLR(80, 0.1) 19 | tqdm_rep = reporters.TQDMReporter(range(args.epochs)) 20 | _callbacks = [tqdm_rep, callbacks.AccuracyCallback()] 21 | with Trainer(model, optimizer, F.cross_entropy, scheduler=scheduler, callbacks=_callbacks) as trainer: 22 | for _ in tqdm_rep: 23 | trainer.train(train_loader) 24 | trainer.test(test_loader) 25 | 26 | 27 | if __name__ == '__main__': 28 | import argparse 29 | 30 | p = argparse.ArgumentParser() 31 | p.add_argument("--epochs", type=int, default=200) 32 | p.add_argument("--batch_size", type=int, default=64) 33 | p.add_argument("--reduction", type=int, default=16) 34 | p.add_argument("--baseline", action="store_true") 35 | args = p.parse_args() 36 | main() 37 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | dependencies = ["torch", "math"] 2 | 3 | 4 | def se_resnet20(**kwargs): 5 | from senet.se_resnet import se_resnet20 as _se_resnet20 6 | 7 | return _se_resnet20(**kwargs) 8 | 9 | 10 | def se_resnet56(**kwargs): 11 | from senet.se_resnet import se_resnet56 as _se_resnet56 12 | 13 | return _se_resnet56(**kwargs) 14 | 15 | 16 | def se_resnet50(**kwargs): 17 | from senet.se_resnet import se_resnet50 as _se_resnet50 18 | 19 | return _se_resnet50(**kwargs) 20 | 21 | 22 | def se_resnet101(**kwargs): 23 | from senet.se_resnet import se_resnet101 as _se_resnet101 24 | 25 | return _se_resnet101(**kwargs) 26 | -------------------------------------------------------------------------------- /imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from homura import optim, lr_scheduler, callbacks, reporters, init_distributed 3 | from homura.trainers import SupervisedTrainer 4 | from homura.vision.data import imagenet_loaders 5 | from torch.nn import functional as F 6 | 7 | from senet.se_resnet import se_resnet50 8 | 9 | 10 | def main(): 11 | if args.distributed: 12 | init_distributed() 13 | 14 | model = se_resnet50(num_classes=1000) 15 | 16 | optimizer = optim.SGD(lr=0.6 / 1024 * args.batch_size, 17 | momentum=0.9, weight_decay=1e-4) 18 | scheduler = lr_scheduler.MultiStepLR([50, 70]) 19 | train_loader, test_loader = imagenet_loaders(args.root, args.batch_size, distributed=args.distributed, 20 | num_train_samples=args.batch_size * 10 if args.debug else None, 21 | num_test_samples=args.batch_size * 10 if args.debug else None) 22 | 23 | c = [callbacks.AccuracyCallback(), callbacks.AccuracyCallback(k=5), 24 | callbacks.LossCallback(), 25 | callbacks.WeightSave('.'), 26 | reporters.TensorboardReporter('.'), 27 | reporters.TQDMReporter(range(args.epochs))] 28 | 29 | with SupervisedTrainer(model, optimizer, F.cross_entropy, 30 | callbacks=c, 31 | scheduler=scheduler, 32 | ) as trainer: 33 | for _ in c[-1]: 34 | trainer.train(train_loader) 35 | trainer.test(test_loader) 36 | 37 | 38 | if __name__ == '__main__': 39 | import miniargs 40 | import warnings 41 | 42 | warnings.filterwarnings( 43 | "ignore", "(Possibly )?corrupt EXIF data", UserWarning) 44 | 45 | p = miniargs.ArgumentParser() 46 | p.add_str("root") 47 | p.add_int("--epochs", default=90) 48 | p.add_int("--batch_size", default=128) 49 | p.add_true("--distributed") 50 | p.add_int("--local_rank", default=-1) 51 | p.add_true("--debug", help="Use less images and less epochs") 52 | args, _else = p.parse(return_unknown=True) 53 | num_device = torch.cuda.device_count() 54 | 55 | print(args) 56 | if args.distributed and args.local_rank == -1: 57 | raise RuntimeError( 58 | f"For distributed training, use python -m torch.distributed.launch " 59 | f"--nproc_per_node={num_device} {__file__} {args.root} ...") 60 | main() 61 | -------------------------------------------------------------------------------- /senet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shualite/SFNet.pytorch/e25c5e2483ade8ab5cbf75d3d2e27bc00d1b3153/senet/__init__.py -------------------------------------------------------------------------------- /senet/baseline.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet for CIFAR dataset proposed in He+15, p 7. and 3 | https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua 4 | """ 5 | 6 | import torch.nn as nn 7 | 8 | 9 | def conv3x3(in_planes, out_planes, stride=1): 10 | "3x3 convolution with padding" 11 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 12 | padding=1, bias=False) 13 | 14 | 15 | class BasicBlock(nn.Module): 16 | def __init__(self, inplanes, planes, stride=1): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | if inplanes != planes: 24 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False), 25 | nn.BatchNorm2d(planes)) 26 | else: 27 | self.downsample = lambda x: x 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = self.downsample(x) 32 | out = self.conv1(x) 33 | out = self.bn1(out) 34 | out = self.relu(out) 35 | 36 | out = self.conv2(out) 37 | out = self.bn2(out) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class PreActBasicBlock(BasicBlock): 46 | def __init__(self, inplanes, planes, stride): 47 | super(PreActBasicBlock, self).__init__(inplanes, planes, stride) 48 | if inplanes != planes: 49 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)) 50 | else: 51 | self.downsample = lambda x: x 52 | self.bn1 = nn.BatchNorm2d(inplanes) 53 | 54 | def forward(self, x): 55 | residual = self.downsample(x) 56 | out = self.bn1(x) 57 | out = self.relu(out) 58 | out = self.conv1(out) 59 | 60 | out = self.bn2(out) 61 | out = self.relu(out) 62 | out = self.conv2(out) 63 | 64 | out += residual 65 | 66 | return out 67 | 68 | 69 | class ResNet(nn.Module): 70 | def __init__(self, block, n_size, num_classes=10): 71 | super(ResNet, self).__init__() 72 | self.inplane = 16 73 | self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) 74 | self.bn1 = nn.BatchNorm2d(self.inplane) 75 | self.relu = nn.ReLU(inplace=True) 76 | self.layer1 = self._make_layer(block, 16, blocks=n_size, stride=1) 77 | self.layer2 = self._make_layer(block, 32, blocks=n_size, stride=2) 78 | self.layer3 = self._make_layer(block, 64, blocks=n_size, stride=2) 79 | self.avgpool = nn.AdaptiveAvgPool2d(1) 80 | self.fc = nn.Linear(64, num_classes) 81 | 82 | self.initialize() 83 | 84 | def initialize(self): 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight) 88 | elif isinstance(m, nn.BatchNorm2d): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, block, planes, blocks, stride): 93 | 94 | strides = [stride] + [1] * (blocks - 1) 95 | layers = [] 96 | for stride in strides: 97 | layers.append(block(self.inplane, planes, stride)) 98 | self.inplane = planes 99 | 100 | return nn.Sequential(*layers) 101 | 102 | def forward(self, x): 103 | x = self.conv1(x) 104 | x = self.bn1(x) 105 | x = self.relu(x) 106 | 107 | x = self.layer1(x) 108 | x = self.layer2(x) 109 | x = self.layer3(x) 110 | 111 | x = self.avgpool(x) 112 | x = x.view(x.size(0), -1) 113 | x = self.fc(x) 114 | 115 | return x 116 | 117 | 118 | class PreActResNet(ResNet): 119 | def __init__(self, block, n_size, num_classes=10): 120 | super(PreActResNet, self).__init__(block, n_size, num_classes) 121 | 122 | self.bn1 = nn.BatchNorm2d(self.inplane) 123 | self.initialize() 124 | 125 | def forward(self, x): 126 | x = self.conv1(x) 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | 134 | x = self.avgpool(x) 135 | x = x.view(x.size(0), -1) 136 | x = self.fc(x) 137 | 138 | return x 139 | 140 | 141 | def resnet20(**kwargs): 142 | model = ResNet(BasicBlock, 3, **kwargs) 143 | return model 144 | 145 | 146 | def resnet32(**kwargs): 147 | model = ResNet(BasicBlock, 5, **kwargs) 148 | return model 149 | 150 | 151 | def resnet56(**kwargs): 152 | model = ResNet(BasicBlock, 9, **kwargs) 153 | return model 154 | 155 | 156 | def resnet110(**kwargs): 157 | model = ResNet(BasicBlock, 18, **kwargs) 158 | return model 159 | 160 | 161 | def preact_resnet20(**kwargs): 162 | model = PreActResNet(PreActBasicBlock, 3, **kwargs) 163 | return model 164 | 165 | 166 | def preact_resnet32(**kwargs): 167 | model = PreActResNet(PreActBasicBlock, 5, **kwargs) 168 | return model 169 | 170 | 171 | def preact_resnet56(**kwargs): 172 | model = PreActResNet(PreActBasicBlock, 9, **kwargs) 173 | return model 174 | 175 | 176 | def preact_resnet110(**kwargs): 177 | model = PreActResNet(PreActBasicBlock, 18, **kwargs) 178 | return model 179 | -------------------------------------------------------------------------------- /senet/se_inception.py: -------------------------------------------------------------------------------- 1 | from senet.se_module import SELayer 2 | from torch import nn 3 | from torchvision.models.inception import Inception3 4 | 5 | 6 | class SEInception3(nn.Module): 7 | def __init__(self, num_classes, aux_logits=True, transform_input=False): 8 | super(SEInception3, self).__init__() 9 | model = Inception3(num_classes=num_classes, aux_logits=aux_logits, 10 | transform_input=transform_input) 11 | model.Mixed_5b.add_module("SELayer", SELayer(192)) 12 | model.Mixed_5c.add_module("SELayer", SELayer(256)) 13 | model.Mixed_5d.add_module("SELayer", SELayer(288)) 14 | model.Mixed_6a.add_module("SELayer", SELayer(288)) 15 | model.Mixed_6b.add_module("SELayer", SELayer(768)) 16 | model.Mixed_6c.add_module("SELayer", SELayer(768)) 17 | model.Mixed_6d.add_module("SELayer", SELayer(768)) 18 | model.Mixed_6e.add_module("SELayer", SELayer(768)) 19 | if aux_logits: 20 | model.AuxLogits.add_module("SELayer", SELayer(768)) 21 | model.Mixed_7a.add_module("SELayer", SELayer(768)) 22 | model.Mixed_7b.add_module("SELayer", SELayer(1280)) 23 | model.Mixed_7c.add_module("SELayer", SELayer(2048)) 24 | 25 | self.model = model 26 | 27 | def forward(self, x): 28 | _, _, h, w = x.size() 29 | if (h, w) != (299, 299): 30 | raise ValueError("input size must be (299, 299)") 31 | 32 | return self.model(x) 33 | 34 | 35 | def se_inception_v3(**kwargs): 36 | return SEInception3(**kwargs) 37 | -------------------------------------------------------------------------------- /senet/se_module.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | return x * y.expand_as(x) 20 | -------------------------------------------------------------------------------- /senet/se_resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.hub import load_state_dict_from_url 3 | from torchvision.models import ResNet 4 | from senet.se_module import SELayer 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 9 | 10 | 11 | class SEBasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 15 | base_width=64, dilation=1, norm_layer=None, 16 | *, reduction=16): 17 | super(SEBasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes, 1) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.se = SELayer(planes, reduction) 24 | self.downsample = downsample 25 | self.stride = stride 26 | 27 | def forward(self, x): 28 | residual = x 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | out = self.se(out) 36 | 37 | if self.downsample is not None: 38 | residual = self.downsample(x) 39 | 40 | out += residual 41 | out = self.relu(out) 42 | 43 | return out 44 | 45 | 46 | class SEBottleneck(nn.Module): 47 | expansion = 4 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 50 | base_width=64, dilation=1, norm_layer=None, 51 | *, reduction=16): 52 | super(SEBottleneck, self).__init__() 53 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 54 | self.bn1 = nn.BatchNorm2d(planes) 55 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | self.bn2 = nn.BatchNorm2d(planes) 58 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 59 | self.bn3 = nn.BatchNorm2d(planes * 4) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.se = SELayer(planes * 4, reduction) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | out = self.se(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | 89 | def se_resnet18(num_classes=1_000): 90 | """Constructs a ResNet-18 model. 91 | 92 | Args: 93 | pretrained (bool): If True, returns a model pre-trained on ImageNet 94 | """ 95 | model = ResNet(SEBasicBlock, [2, 2, 2, 2], num_classes=num_classes) 96 | model.avgpool = nn.AdaptiveAvgPool2d(1) 97 | return model 98 | 99 | 100 | def se_resnet34(num_classes=1_000): 101 | """Constructs a ResNet-34 model. 102 | 103 | Args: 104 | pretrained (bool): If True, returns a model pre-trained on ImageNet 105 | """ 106 | model = ResNet(SEBasicBlock, [3, 4, 6, 3], num_classes=num_classes) 107 | model.avgpool = nn.AdaptiveAvgPool2d(1) 108 | return model 109 | 110 | 111 | def se_resnet50(num_classes=1_000, pretrained=False): 112 | """Constructs a ResNet-50 model. 113 | 114 | Args: 115 | pretrained (bool): If True, returns a model pre-trained on ImageNet 116 | """ 117 | model = ResNet(SEBottleneck, [3, 4, 6, 3], num_classes=num_classes) 118 | model.avgpool = nn.AdaptiveAvgPool2d(1) 119 | if pretrained: 120 | model.load_state_dict(load_state_dict_from_url( 121 | "https://github.com/moskomule/senet.pytorch/releases/download/archive/seresnet50-60a8950a85b2b.pkl")) 122 | return model 123 | 124 | 125 | def se_resnet101(num_classes=1_000): 126 | """Constructs a ResNet-101 model. 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | model = ResNet(SEBottleneck, [3, 4, 23, 3], num_classes=num_classes) 132 | model.avgpool = nn.AdaptiveAvgPool2d(1) 133 | return model 134 | 135 | 136 | def se_resnet152(num_classes=1_000): 137 | """Constructs a ResNet-152 model. 138 | 139 | Args: 140 | pretrained (bool): If True, returns a model pre-trained on ImageNet 141 | """ 142 | model = ResNet(SEBottleneck, [3, 8, 36, 3], num_classes=num_classes) 143 | model.avgpool = nn.AdaptiveAvgPool2d(1) 144 | return model 145 | 146 | 147 | class CifarSEBasicBlock(nn.Module): 148 | def __init__(self, inplanes, planes, stride=1, reduction=16): 149 | super(CifarSEBasicBlock, self).__init__() 150 | self.conv1 = conv3x3(inplanes, planes, stride) 151 | self.bn1 = nn.BatchNorm2d(planes) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.conv2 = conv3x3(planes, planes) 154 | self.bn2 = nn.BatchNorm2d(planes) 155 | self.se = SELayer(planes, reduction) 156 | if inplanes != planes: 157 | self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False), 158 | nn.BatchNorm2d(planes)) 159 | else: 160 | self.downsample = lambda x: x 161 | self.stride = stride 162 | 163 | def forward(self, x): 164 | residual = self.downsample(x) 165 | out = self.conv1(x) 166 | out = self.bn1(out) 167 | out = self.relu(out) 168 | 169 | out = self.conv2(out) 170 | out = self.bn2(out) 171 | out = self.se(out) 172 | 173 | out += residual 174 | out = self.relu(out) 175 | 176 | return out 177 | 178 | 179 | class CifarSEResNet(nn.Module): 180 | def __init__(self, block, n_size, num_classes=10, reduction=16): 181 | super(CifarSEResNet, self).__init__() 182 | self.inplane = 16 183 | self.conv1 = nn.Conv2d( 184 | 3, self.inplane, kernel_size=3, stride=1, padding=1, bias=False) 185 | self.bn1 = nn.BatchNorm2d(self.inplane) 186 | self.relu = nn.ReLU(inplace=True) 187 | self.layer1 = self._make_layer( 188 | block, 16, blocks=n_size, stride=1, reduction=reduction) 189 | self.layer2 = self._make_layer( 190 | block, 32, blocks=n_size, stride=2, reduction=reduction) 191 | self.layer3 = self._make_layer( 192 | block, 64, blocks=n_size, stride=2, reduction=reduction) 193 | self.avgpool = nn.AdaptiveAvgPool2d(1) 194 | self.fc = nn.Linear(64, num_classes) 195 | self.initialize() 196 | 197 | def initialize(self): 198 | for m in self.modules(): 199 | if isinstance(m, nn.Conv2d): 200 | nn.init.kaiming_normal_(m.weight) 201 | elif isinstance(m, nn.BatchNorm2d): 202 | nn.init.constant_(m.weight, 1) 203 | nn.init.constant_(m.bias, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride, reduction): 206 | strides = [stride] + [1] * (blocks - 1) 207 | layers = [] 208 | for stride in strides: 209 | layers.append(block(self.inplane, planes, stride, reduction)) 210 | self.inplane = planes 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def forward(self, x): 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | 219 | x = self.layer1(x) 220 | x = self.layer2(x) 221 | x = self.layer3(x) 222 | 223 | x = self.avgpool(x) 224 | x = x.view(x.size(0), -1) 225 | x = self.fc(x) 226 | 227 | return x 228 | 229 | 230 | class CifarSEPreActResNet(CifarSEResNet): 231 | def __init__(self, block, n_size, num_classes=10, reduction=16): 232 | super(CifarSEPreActResNet, self).__init__( 233 | block, n_size, num_classes, reduction) 234 | self.bn1 = nn.BatchNorm2d(self.inplane) 235 | self.initialize() 236 | 237 | def forward(self, x): 238 | x = self.conv1(x) 239 | x = self.layer1(x) 240 | x = self.layer2(x) 241 | x = self.layer3(x) 242 | 243 | x = self.bn1(x) 244 | x = self.relu(x) 245 | 246 | x = self.avgpool(x) 247 | x = x.view(x.size(0), -1) 248 | x = self.fc(x) 249 | 250 | 251 | def se_resnet20(**kwargs): 252 | """Constructs a ResNet-18 model. 253 | 254 | """ 255 | model = CifarSEResNet(CifarSEBasicBlock, 3, **kwargs) 256 | return model 257 | 258 | 259 | def se_resnet32(**kwargs): 260 | """Constructs a ResNet-34 model. 261 | 262 | """ 263 | model = CifarSEResNet(CifarSEBasicBlock, 5, **kwargs) 264 | return model 265 | 266 | 267 | def se_resnet56(**kwargs): 268 | """Constructs a ResNet-34 model. 269 | 270 | """ 271 | model = CifarSEResNet(CifarSEBasicBlock, 9, **kwargs) 272 | return model 273 | 274 | 275 | def se_preactresnet20(**kwargs): 276 | """Constructs a ResNet-18 model. 277 | 278 | """ 279 | model = CifarSEPreActResNet(CifarSEBasicBlock, 3, **kwargs) 280 | return model 281 | 282 | 283 | def se_preactresnet32(**kwargs): 284 | """Constructs a ResNet-34 model. 285 | 286 | """ 287 | model = CifarSEPreActResNet(CifarSEBasicBlock, 5, **kwargs) 288 | return model 289 | 290 | 291 | def se_preactresnet56(**kwargs): 292 | """Constructs a ResNet-34 model. 293 | 294 | """ 295 | model = CifarSEPreActResNet(CifarSEBasicBlock, 9, **kwargs) 296 | return model 297 | -------------------------------------------------------------------------------- /sfnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Shualite/SFNet.pytorch/e25c5e2483ade8ab5cbf75d3d2e27bc00d1b3153/sfnet/__init__.py -------------------------------------------------------------------------------- /sfnet/sfnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3mb4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | } 19 | 20 | 21 | def conv3x3(in_planes, out_planes, stride=1): 22 | """3x3 convolution with padding""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 24 | padding=1, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | residual = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | residual = self.downsample(x) 52 | 53 | out += residual 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 67 | padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = nn.BatchNorm2d(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv3(out) 87 | out = self.bn3(out) 88 | 89 | if self.downsample is not None: 90 | residual = self.downsample(x) 91 | 92 | out += residual 93 | out = self.relu(out) 94 | 95 | return out 96 | 97 | 98 | class SFNet_ResNet(nn.Module): 99 | 100 | def __init__(self, block, layers, num_classes=7, scale=1): 101 | self.inplanes = 64 102 | super(SFNet_ResNet, self).__init__() 103 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 104 | bias=False) 105 | self.bn1 = nn.BatchNorm2d(64) 106 | self.relu1 = nn.ReLU(inplace=True) 107 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 108 | self.layer1 = self._make_layer(block, 64, layers[0]) 109 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 110 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 111 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 112 | # self.avgpool = nn.AvgPool2d(7, stride=1) 113 | # self.fc = nn.Linear(512 * block.expansion, num_classes) 114 | 115 | # Top layer 116 | self.toplayer = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0) # Reduce channels 117 | self.toplayer_bn = nn.BatchNorm2d(256) 118 | self.toplayer_relu = nn.ReLU(inplace=True) 119 | 120 | # Smooth layers 121 | self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 122 | self.smooth1_bn = nn.BatchNorm2d(256) 123 | self.smooth1_relu = nn.ReLU(inplace=True) 124 | 125 | self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 126 | self.smooth2_bn = nn.BatchNorm2d(256) 127 | self.smooth2_relu = nn.ReLU(inplace=True) 128 | 129 | self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 130 | self.smooth3_bn = nn.BatchNorm2d(256) 131 | self.smooth3_relu = nn.ReLU(inplace=True) 132 | 133 | # Lateral layers 134 | self.latlayer1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) 135 | self.latlayer1_bn = nn.BatchNorm2d(256) 136 | self.latlayer1_relu = nn.ReLU(inplace=True) 137 | 138 | self.latlayer2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0) 139 | self.latlayer2_bn = nn.BatchNorm2d(256) 140 | self.latlayer2_relu = nn.ReLU(inplace=True) 141 | 142 | self.latlayer3 = nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0) 143 | self.latlayer3_bn = nn.BatchNorm2d(256) 144 | self.latlayer3_relu = nn.ReLU(inplace=True) 145 | 146 | # flow layer 147 | self.flowconv1 = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 148 | self.flowconv2 = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 149 | self.flowconv3 = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 150 | 151 | self.conv2 = nn.Conv2d(1024, 256, kernel_size=3, stride=1, padding=1) 152 | self.bn2 = nn.BatchNorm2d(256) 153 | self.relu2 = nn.ReLU(inplace=True) 154 | self.conv3 = nn.Conv2d(256, num_classes, kernel_size=1, stride=1, padding=0) 155 | 156 | self.scale = scale 157 | 158 | for m in self.modules(): 159 | if isinstance(m, nn.Conv2d): 160 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 161 | m.weight.data.normal_(0, math.sqrt(2. / n)) 162 | elif isinstance(m, nn.BatchNorm2d): 163 | m.weight.data.fill_(1) 164 | m.bias.data.zero_() 165 | 166 | def _make_layer(self, block, planes, blocks, stride=1): 167 | downsample = None 168 | if stride != 1 or self.inplanes != planes * block.expansion: 169 | downsample = nn.Sequential( 170 | nn.Conv2d(self.inplanes, planes * block.expansion, 171 | kernel_size=1, stride=stride, bias=False), 172 | nn.BatchNorm2d(planes * block.expansion), 173 | ) 174 | 175 | layers = [] 176 | layers.append(block(self.inplanes, planes, stride, downsample)) 177 | self.inplanes = planes * block.expansion 178 | for i in range(1, blocks): 179 | layers.append(block(self.inplanes, planes)) 180 | 181 | return nn.Sequential(*layers) 182 | 183 | def _upsample(self, x, y, scale=1): 184 | _, _, H, W = y.size() 185 | return F.upsample(x, size=(H // scale, W // scale), mode='bilinear') 186 | 187 | def _upsample_add(self, x, y): 188 | _, _, H, W = y.size() 189 | return F.upsample(x, size=(H, W), mode='bilinear') + y 190 | 191 | # Semantic Flow for Fast and Accurate Scene Parsing arXiv:2002.10120v1 192 | # Flow Align Module 193 | def _flow_align_module(self, featmap_front, featmap_latter, func): 194 | B, C, H, W = featmap_latter.size() 195 | fuse = torch.cat((featmap_front, self._upsample(featmap_latter, featmap_front)), 1) 196 | 197 | flow = func(fuse) 198 | flow = self._upsample(flow, featmap_latter) 199 | flow = flow.permute(0, 2, 3, 1) 200 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 201 | grid = torch.stack((grid_x, grid_y), 2).float() 202 | grid.requires_grad = False 203 | grid = grid.type_as(featmap_latter) 204 | vgrid = grid + flow 205 | # scale grid to [-1, 1] 206 | vgrid_x = 2.0 * vgrid[:,:,:,0] / max(W-1, 1) - 1.0 207 | vgrid_y = 2.0 * vgrid[:,:,:,1] / max(H-1, 1) - 1.0 208 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 209 | output = F.grid_sample(featmap_latter, vgrid_scaled, mode='bilinear', padding_mode='zeros') 210 | return output 211 | 212 | def forward(self, x): 213 | h = x 214 | h = self.conv1(h) 215 | h = self.bn1(h) 216 | h = self.relu1(h) 217 | h = self.maxpool(h) 218 | 219 | h = self.layer1(h) 220 | c2 = h 221 | h = self.layer2(h) 222 | c3 = h 223 | h = self.layer3(h) 224 | c4 = h 225 | h = self.layer4(h) 226 | c5 = h 227 | 228 | # Top-down 229 | p5 = self.toplayer(c5) 230 | p5 = self.toplayer_relu(self.toplayer_bn(p5)) 231 | 232 | c4 = self.latlayer1(c4) 233 | c4 = self.latlayer1_relu(self.latlayer1_bn(c4)) 234 | p5_flow = self._flow_align_module(c4, p5, self.flowconv1) 235 | p4 = self._upsample_add(p5_flow, c4) 236 | p4 = self.smooth1(p4) 237 | p4 = self.smooth1_relu(self.smooth1_bn(p4)) 238 | 239 | c3 = self.latlayer2(c3) 240 | c3 = self.latlayer2_relu(self.latlayer2_bn(c3)) 241 | p4_flow = self._flow_align_module(c3, p4, self.flowconv2) 242 | p3 = self._upsample_add(p4_flow, c3) 243 | p3 = self.smooth2(p3) 244 | p3 = self.smooth2_relu(self.smooth2_bn(p3)) 245 | 246 | c2 = self.latlayer3(c2) 247 | c2 = self.latlayer3_relu(self.latlayer3_bn(c2)) 248 | p3_flow = self._flow_align_module(c2, p3, self.flowconv3) 249 | p2 = self._upsample_add(p3_flow, c2) 250 | p2 = self.smooth3(p2) 251 | p2 = self.smooth3_relu(self.smooth3_bn(p2)) 252 | 253 | p3 = self._upsample(p3, p2) 254 | p4 = self._upsample(p4, p2) 255 | p5 = self._upsample(p5, p2) 256 | 257 | out = torch.cat((p2, p3, p4, p5), 1) 258 | out = self.conv2(out) 259 | out = self.relu2(self.bn2(out)) 260 | out = self.conv3(out) 261 | out = self._upsample(out, x, scale=self.scale) 262 | 263 | return out 264 | 265 | 266 | def sf_resnet18(pretrained=False, **kwargs): 267 | """Constructs a ResNet-18 model. 268 | 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | """ 272 | model = SFNet_ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 273 | if pretrained: 274 | model.load_state_dict(model_zoo.load_url(model_urls['sf_resnet18'])) 275 | return model 276 | 277 | 278 | def sf_resnet34(pretrained=False, **kwargs): 279 | """Constructs a ResNet-34 model. 280 | 281 | Args: 282 | pretrained (bool): If True, returns a model pre-trained on ImageNet 283 | """ 284 | model = SFNet_ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 285 | if pretrained: 286 | model.load_state_dict(model_zoo.load_url(model_urls['sf_resnet34'])) 287 | return model 288 | 289 | 290 | def sf_resnet50(pretrained=False, **kwargs): 291 | """Constructs a ResNet-50 model. 292 | 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | """ 296 | model = SFNet_ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 297 | if pretrained: 298 | pretrained_model = model_zoo.load_url(model_urls['sf_resnet50']) 299 | state = model.state_dict() 300 | for key in state.keys(): 301 | if key in pretrained_model.keys(): 302 | state[key] = pretrained_model[key] 303 | model.load_state_dict(state) 304 | return model 305 | 306 | 307 | def sf_resnet101(pretrained=False, **kwargs): 308 | """Constructs a ResNet-101 model. 309 | 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | """ 313 | model = SFNet_ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 314 | if pretrained: 315 | pretrained_model = model_zoo.load_url(model_urls['sf_resnet101']) 316 | state = model.state_dict() 317 | for key in state.keys(): 318 | if key in pretrained_model.keys(): 319 | state[key] = pretrained_model[key] 320 | model.load_state_dict(state) 321 | return model 322 | 323 | def sf_resnet152(pretrained=False, **kwargs): 324 | """Constructs a ResNet-152 model. 325 | 326 | Args: 327 | pretrained (bool): If True, returns a model pre-trained on ImageNet 328 | """ 329 | model = SFNet_ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 330 | if pretrained: 331 | pretrained_model = model_zoo.load_url(model_urls['sf_resnet152']) 332 | state = model.state_dict() 333 | for key in state.keys(): 334 | if key in pretrained_model.keys(): 335 | state[key] = pretrained_model[key] 336 | model.load_state_dict(state) 337 | return model 338 | --------------------------------------------------------------------------------