├── .gitignore
├── README.md
├── example.py
└── models
    ├── resnet_2d.py
    ├── senet.py
    ├── stnet.py
    └── temporal_xception.py

/.gitignore:
--------------------------------------------------------------------------------
# git ls-files --others --exclude-from=.git/info/exclude
# Lines that start with '#' are comments.
# For a project mostly in C, the following would be a good set of
# exclude patterns (uncomment them if you want to use them):
# *.[oa]
# *~
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# StNet
This repository holds the architecture of the network proposed in this paper: https://arxiv.org/pdf/1811.01549.pdf

The paper is very inspiring, but the team didn't release their code. I'm not among the authors, but I was eager to test the model, so I implemented it following the description given in the paper. The actual implementation may therefore differ from the authors'.

This work is largely based on the following codebases:

### Squeeze-and-Excitation ResNet

https://github.com/hujie-frank/SENet

### Traditional ResNet

https://github.com/pytorch/vision.git

### Xception

https://github.com/tstandley/Xception-PyTorch.git
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
from models import stnet


model = stnet.stnet50(input_channels=3, num_classes=400, T=7, N=5)
--------------------------------------------------------------------------------
/models/resnet_2d.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.model_zoo as model_zoo


__all__ = ["ResNet", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]


model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.inplanes = inplanes
        self.planes = planes
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet18"]))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet34"]))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet50"]))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet101"]))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls["resnet152"]))
    return model
--------------------------------------------------------------------------------
/models/senet.py:
--------------------------------------------------------------------------------
"""
ResNet code gently borrowed from
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
"""
from __future__ import print_function, division, absolute_import
from collections import OrderedDict
import math

import torch.nn as nn
from torch.utils import model_zoo

__all__ = [
    "SENet",
    "senet154",
    "se_resnet50",
    "se_resnet101",
    "se_resnet152",
    "se_resnext50_32x4d",
    "se_resnext101_32x4d",
]

pretrained_settings = {
    "senet154": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
    "se_resnet50": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
    "se_resnet101": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
    "se_resnet152": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
    "se_resnext50_32x4d": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
    "se_resnext101_32x4d": {
        "imagenet": {
            "url": "http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth",
            "input_space": "RGB",
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
    },
}


class SEModule(nn.Module):
    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class Bottleneck(nn.Module):
    """
    Base class for bottlenecks that implements `forward()` method.
    """

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = self.se_module(out) + residual
        out = self.relu(out)

        return out


class SEBottleneck(Bottleneck):
    """
    Bottleneck for SENet154.
    """

    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes * 2)
        self.conv2 = nn.Conv2d(
            planes * 2,
            planes * 4,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(planes * 4)
        self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNetBottleneck(Bottleneck):
    """
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
    implementation and uses `stride=stride` in `conv1` and not in `conv2`
    (the latter is used in the torchvision implementation of ResNet).
    """

    expansion = 4

    def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None):
        super(SEResNetBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(
            inplanes, planes, kernel_size=1, bias=False, stride=stride
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, padding=1, groups=groups, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SEResNeXtBottleneck(Bottleneck):
    """
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        groups,
        reduction,
        stride=1,
        downsample=None,
        base_width=4,
    ):
        super(SEResNeXtBottleneck, self).__init__()
        width = math.floor(planes * (base_width / 64)) * groups
        self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = nn.Conv2d(
            width,
            width,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=groups,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se_module = SEModule(planes * 4, reduction=reduction)
        self.downsample = downsample
        self.stride = stride


class SENet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        groups,
        reduction,
        dropout_p=0.2,
        inplanes=128,
        input_3x3=True,
        downsample_kernel_size=3,
        downsample_padding=1,
        num_classes=1000,
    ):
        """
        Parameters
        ----------
        block (nn.Module): Bottleneck class.
            - For SENet154: SEBottleneck
            - For SE-ResNet models: SEResNetBottleneck
            - For SE-ResNeXt models: SEResNeXtBottleneck
        layers (list of ints): Number of residual blocks for 4 layers of the
            network (layer1...layer4).
        groups (int): Number of groups for the 3x3 convolution in each
            bottleneck block.
            - For SENet154: 64
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 32
        reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
            - For all models: 16
        dropout_p (float or None): Drop probability for the Dropout layer.
            If `None` the Dropout layer is not used.
            - For SENet154: 0.2
            - For SE-ResNet models: None
            - For SE-ResNeXt models: None
        inplanes (int): Number of input channels for layer1.
            - For SENet154: 128
            - For SE-ResNet models: 64
            - For SE-ResNeXt models: 64
        input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        downsample_kernel_size (int): Kernel size for downsampling convolutions
            in layer2, layer3 and layer4.
            - For SENet154: 3
            - For SE-ResNet models: 1
            - For SE-ResNeXt models: 1
        downsample_padding (int): Padding for downsampling convolutions in
            layer2, layer3 and layer4.
            - For SENet154: 1
            - For SE-ResNet models: 0
            - For SE-ResNeXt models: 0
        num_classes (int): Number of outputs in `last_linear` layer.
            - For all models: 1000
        """
        super(SENet, self).__init__()
        self.inplanes = inplanes
        if input_3x3:
            layer0_modules = [
                ("conv1", nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False)),
                ("bn1", nn.BatchNorm2d(64)),
                ("relu1", nn.ReLU(inplace=True)),
                ("conv2", nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
                ("bn2", nn.BatchNorm2d(64)),
                ("relu2", nn.ReLU(inplace=True)),
                ("conv3", nn.Conv2d(64, inplanes, 3, stride=1, padding=1, bias=False)),
                ("bn3", nn.BatchNorm2d(inplanes)),
                ("relu3", nn.ReLU(inplace=True)),
            ]
        else:
            layer0_modules = [
                (
                    "conv1",
                    nn.Conv2d(
                        3, inplanes, kernel_size=7, stride=2, padding=3, bias=False
                    ),
                ),
                ("bn1", nn.BatchNorm2d(inplanes)),
                ("relu1", nn.ReLU(inplace=True)),
            ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(("pool", nn.MaxPool2d(3, stride=2, ceil_mode=True)))
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0,
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )
        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )
        self.avg_pool = nn.AvgPool2d(7, stride=1)
        self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
        self.last_linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        groups,
        reduction,
        stride=1,
        downsample_kernel_size=1,
        downsample_padding=0,
    ):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=downsample_kernel_size,
                    stride=stride,
                    padding=downsample_padding,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, groups, reduction, stride, downsample)
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def features(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def logits(self, x):
        x = self.avg_pool(x)
        if self.dropout is not None:
            x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x


def initialize_pretrained_model(model, num_classes, settings):
    assert (
        num_classes == settings["num_classes"]
    ), "num_classes should be {}, but is {}".format(
        settings["num_classes"], num_classes
    )
    model.load_state_dict(model_zoo.load_url(settings["url"]))
    model.input_space = settings["input_space"]
    model.input_size = settings["input_size"]
    model.input_range = settings["input_range"]
    model.mean = settings["mean"]
    model.std = settings["std"]


def senet154(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEBottleneck,
        [3, 8, 36, 3],
        groups=64,
        reduction=16,
        dropout_p=0.2,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["senet154"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet50(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEResNetBottleneck,
        [3, 4, 6, 3],
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["se_resnet50"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet101(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEResNetBottleneck,
        [3, 4, 23, 3],
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["se_resnet101"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnet152(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEResNetBottleneck,
        [3, 8, 36, 3],
        groups=1,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["se_resnet152"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model


def se_resnext50_32x4d(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEResNeXtBottleneck,
        [3, 4, 6, 3],
        groups=32,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["se_resnext50_32x4d"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model

def se_resnext101_32x4d(num_classes=1000, pretrained="imagenet"):
    model = SENet(
        SEResNeXtBottleneck,
        [3, 4, 23, 3],
        groups=32,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        num_classes=num_classes,
    )
    if pretrained is not None:
        settings = pretrained_settings["se_resnext101_32x4d"][pretrained]
        initialize_pretrained_model(model, num_classes, settings)
    return model
--------------------------------------------------------------------------------
/models/stnet.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import pretrainedmodels
import torch.nn as nn
import torch.nn.functional as F

from models.senet import SEResNeXtBottleneck
from models.temporal_xception import TemporalXception


class StNet(nn.Module):
    def __init__(
        self,
        block,
        layers,
        groups,
        reduction,
        dropout_p=0.2,
        inplanes=128,
        input_3x3=True,
        downsample_kernel_size=3,
        downsample_padding=1,
        num_classes=1000,
        T=7,
        N=5,
        input_channels=3,
    ):
        # NOTE: dropout_p and input_3x3 are accepted for parity with
        # models.senet.SENet's signature but are not used here.
        super(StNet, self).__init__()
        self.inplanes = inplanes
        self.T = T
        self.N = N
        # layer0 consumes "superimages": N consecutive frames stacked along
        # the channel axis, hence input_channels * N input channels.
        layer0_modules = [
            (
                "conv1",
                nn.Conv2d(
                    input_channels * self.N,
                    inplanes,
                    kernel_size=7,
                    stride=2,
                    padding=3,
                    bias=False,
                ),
            ),
            ("bn1", nn.BatchNorm2d(inplanes)),
            ("relu1", nn.ReLU(inplace=True)),
        ]
        # To preserve compatibility with Caffe weights `ceil_mode=True`
        # is used instead of `padding=1`.
        layer0_modules.append(("pool", nn.MaxPool2d(3, stride=2, ceil_mode=True)))
        self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
        self.layer1 = self._make_layer(
            block,
            planes=64,
            blocks=layers[0],
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=1,
            downsample_padding=0,
        )
        self.layer2 = self._make_layer(
            block,
            planes=128,
            blocks=layers[1],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )

        self.temp1 = TemporalBlock(512)

        self.layer3 = self._make_layer(
            block,
            planes=256,
            blocks=layers[2],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )
        self.temp2 = TemporalBlock(1024)
        self.layer4 = self._make_layer(
            block,
            planes=512,
            blocks=layers[3],
            stride=2,
            groups=groups,
            reduction=reduction,
            downsample_kernel_size=downsample_kernel_size,
            downsample_padding=downsample_padding,
        )

        self.xception = TemporalXception(2048, 2048)
        self.last_linear = nn.Linear(2048, num_classes)

    def _make_layer(
        self,
        block,
        planes,
        blocks,
        groups,
        reduction,
        stride=1,
        downsample_kernel_size=1,
        downsample_padding=0,
    ):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=downsample_kernel_size,
                    stride=stride,
                    padding=downsample_padding,
                    bias=False,
                ),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, groups, reduction, stride, downsample)
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups, reduction))

        return nn.Sequential(*layers)

    def forward(self, x):
        # size (batch_size, channels, video_length = T * N, height, width)
        B, C, L, H, W = x.size()
        assert self.T * self.N == L
        # Pack each group of N consecutive frames into one superimage:
        # (B, C, T*N, H, W) -> (B*T, N*C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(B * self.T, self.N * C, H, W)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # size (batch_size*T, Ci, Hi, Wi); regroup the T feature maps along a
        # temporal axis for the first 3D temporal block
        x = x.view(B, self.T, x.size(1), x.size(2), x.size(3))
        B, T, C, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)
        x = self.temp1(x)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(B * T, C, H, W)
        x = self.layer3(x)
        # size (batch_size*T, Ci, Hi, Wi); same regrouping for the second
        # temporal block
        x = x.view(B, self.T, x.size(1), x.size(2), x.size(3))
        B, T, C, H, W = x.size()
        x = x.permute(0, 2, 1, 3, 4)
        x = self.temp2(x)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        x = x.view(B * T, C, H, W)
        x = self.layer4(x)
        # size (batch_size*T, Ci, Hi, Wi)
        size = x.size()
        x = F.avg_pool2d(x, kernel_size=(size[2], size[3]))
        # size (batch_size*T, Ci, 1, 1)
        x = x.view(B, self.T, size[1]).permute(0, 2, 1)
        # size (batch_size, Ci, T)
        x = self.xception(x)
        x = self.last_linear(x)

        return x


class TemporalBlock(nn.Module):
    def __init__(self, channels):
        super(TemporalBlock, self).__init__()
        self.channels = channels
        # Temporal-only 3D convolution: kernel 3 along time, 1x1 spatially.
        self.conv1 = nn.Conv3d(
            channels,
            channels,
            kernel_size=(3, 1, 1),
            stride=1,
            padding=(1, 0, 0),
            bias=False,
        )
        self.bn1 = nn.BatchNorm3d(channels)
        self.relu = nn.ReLU(inplace=True)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # Dirac initialization: the convolution starts out as an
                # identity mapping along the temporal axis.
                nn.init.dirac_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return x


def load_weights(model, state):
    """Load 2D ImageNet backbone weights into an StNet.

    The weights of the first convolution are inflated: they are repeated
    across the N frames stacked in each superimage and divided by N so the
    initial activations keep the same scale.
    """
    pretrained_dict = {}
    model_state = model.state_dict()
    for name, param in state.items():
        if name.startswith("layer0.conv1"):
            pretrained_dict[name] = param.repeat(1, model.N, 1, 1) / model.N
        elif name not in model_state or model_state[name].shape != param.shape:
            # e.g. `last_linear` when num_classes != 1000: keep the fresh
            # initialization instead of the pretrained tensor.
            continue
        else:
            pretrained_dict[name] = param

    model_state.update(pretrained_dict)
    model.load_state_dict(model_state)
    return model


def stnet50(**kwargs):
    """
    Construct an StNet with an SE-ResNeXt-50 (32x4d) backbone initialized
    from ImageNet-pretrained weights.
    """

    model = StNet(
        SEResNeXtBottleneck,
        [3, 4, 6, 3],
        groups=32,
        reduction=16,
        dropout_p=None,
        inplanes=64,
        input_3x3=False,
        downsample_kernel_size=1,
        downsample_padding=0,
        **kwargs,
    )
    model = load_weights(
        model,
        pretrainedmodels.__dict__["se_resnext50_32x4d"](
            num_classes=1000, pretrained="imagenet"
        ).state_dict(),
    )

    return model
--------------------------------------------------------------------------------
/models/temporal_xception.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class TemporalXception(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TemporalXception, self).__init__()
        self.bn1 = nn.BatchNorm1d(in_channels)
        self.sepconv1 = SeparableConv1d(
            in_channels, out_channels, kernel_size=3, padding=1
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.sepconv2 = SeparableConv1d(
            out_channels, out_channels, kernel_size=3, padding=1
        )
        # 1x1 shortcut so the residual branch can change channel count.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0)

        self.bn3 = nn.BatchNorm1d(out_channels)
        for m in [self.bn1, self.bn2, self.bn3, self.conv]:
            if isinstance(m, nn.Conv1d):
                nn.init.dirac_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm1d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        B, C, T = x.size()
        x = self.bn1(x)
        x2 = self.conv(x)
        x1 = self.sepconv1(x)
        x1 = F.relu(self.bn2(x1))
        x1 = self.sepconv2(x1)
        # size (B, C, T); average the separable branch and the shortcut
        x = F.relu(self.bn3(x1 + x2)).div(2.0)
        # global max pooling over the temporal axis
        x = F.max_pool1d(x, kernel_size=x.size(-1))
        # size (B, C, 1)
        return x.view(x.size(0), x.size(1))


class SeparableConv1d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        bias=False,
    ):
        super(SeparableConv1d, self).__init__()
        self.conv1 = nn.Conv1d(
            in_channels,
            in_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups=in_channels,
            bias=bias,
        )
        # Initialize the depthwise convolution as an identity filter:
        # a single centered tap per channel.
        self.conv1.weight.data.zero_()
        self.conv1.weight[:, :, kernel_size // 2].data.fill_(1)

        self.pointwise = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )
        nn.init.dirac_(self.pointwise.weight)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x
--------------------------------------------------------------------------------
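
A quick smoke test (a sketch, not a file from the original repository; it
assumes the `pretrainedmodels` package is installed, since stnet.stnet50()
downloads ImageNet weights for its SE-ResNeXt-50 backbone):

    import torch

    from models import stnet

    model = stnet.stnet50(input_channels=3, num_classes=400, T=7, N=5)
    model.eval()

    # StNet consumes clips of T * N frames, laid out as
    # (batch, channels, T * N, height, width).
    clip = torch.randn(2, 3, 7 * 5, 224, 224)
    with torch.no_grad():
        logits = model(clip)
    print(logits.shape)  # torch.Size([2, 400])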
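
The reshape at the top of StNet.forward packs each group of N consecutive
frames into one "superimage" with N * C channels. A standalone sketch of that
packing (tensor sizes are illustrative), with a check of where each frame
lands:

    import torch

    B, C, T, N, H, W = 2, 3, 7, 5, 4, 4
    video = torch.randn(B, C, T * N, H, W)

    x = video.permute(0, 2, 1, 3, 4).contiguous()  # (B, T*N, C, H, W)
    superimages = x.view(B * T, N * C, H, W)       # (B*T, N*C, H, W)

    # Frame k of segment t becomes channel block k of superimage (b, t):
    b, t, k = 1, 3, 2
    assert torch.equal(
        superimages[b * T + t, k * C : (k + 1) * C],
        video[b, :, t * N + k],
    )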
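
TemporalBlock is dirac-initialized, so at initialization its temporal
convolution is an identity mapping along time (the BatchNorm and ReLU that
follow still apply). A standalone check of that property:

    import torch
    import torch.nn as nn

    channels = 8
    conv = nn.Conv3d(
        channels, channels, kernel_size=(3, 1, 1), padding=(1, 0, 0), bias=False
    )
    nn.init.dirac_(conv.weight)  # weight[i, i, 1, 0, 0] = 1, zeros elsewhere

    x = torch.randn(2, channels, 7, 5, 5)  # (batch, channels, T, H, W)
    with torch.no_grad():
        y = conv(x)
    assert torch.allclose(x, y)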