├── README.md
├── frn_layer
│   ├── __init__.py
│   └── frn_layer.py
├── preact_resnet
│   ├── __init__.py
│   ├── preact_resnet.py
│   └── preact_resnet_frn.py
└── test
    └── test_models.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# filter-response-normalization-layer-pytorch (FRN)

Unofficial PyTorch implementation of the Filter Response Normalization (FRN) layer. FRN normalizes each channel of each sample by its mean squared activation over the spatial dimensions, so it has no dependence on the batch size or on the other samples in the batch.

## Make a pre-activation ResNet-50 model with the FRN layer

```python
from preact_resnet import preact_resnet50_frn

model = preact_resnet50_frn(num_classes=1000)
```

## Reference

[Filter Response Normalization Layer: Eliminating Batch Dependence in the Training of Deep Neural Networks](https://arxiv.org/abs/1911.09737)
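## Quick forward-pass check

A dummy batch is enough to confirm the model runs end to end; the batch size, resolution, and `num_classes=10` below are arbitrary:

```python
import torch

from preact_resnet import preact_resnet50_frn

model = preact_resnet50_frn(num_classes=10)
x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([2, 10])
```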
--------------------------------------------------------------------------------
/frn_layer/__init__.py:
--------------------------------------------------------------------------------
from .frn_layer import FilterResponseNormLayer
--------------------------------------------------------------------------------
/frn_layer/frn_layer.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import Parameter


class FilterResponseNormLayer(nn.Module):
    def __init__(self, num_features, eps=1e-6):
        super(FilterResponseNormLayer, self).__init__()
        self.num_features = num_features
        # Per-channel learnable parameters, broadcast over (N, C, H, W).
        self.tau = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
        self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
        # eps is learnable, as the paper suggests for 1x1 feature maps;
        # torch.abs() in forward() keeps its effective value positive.
        self.eps = Parameter(torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.tau)
        nn.init.zeros_(self.beta)
        nn.init.ones_(self.gamma)

    def forward(self, input):
        # nu2: mean squared activation over the spatial extent (H, W).
        nu2 = torch.mean(input ** 2, dim=(2, 3), keepdim=True)
        input = input * torch.rsqrt(nu2 + torch.abs(self.eps))
        # Thresholded Linear Unit (TLU): max(gamma * x + beta, tau).
        return torch.max(self.gamma * input + self.beta, self.tau)

    def extra_repr(self):
        return '{}'.format(self.num_features)
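At default initialization (`gamma=1`, `beta=0`, `tau=0`) the layer reduces to `max(x / sqrt(nu2 + |eps|), 0)`, so its output can be checked against a hand-written version. A minimal sketch, assuming it is run from the repository root so that `frn_layer` is importable:

```python
import torch

from frn_layer import FilterResponseNormLayer

torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
frn = FilterResponseNormLayer(num_features=8)

# Recompute the forward pass by hand using the default parameter values.
nu2 = (x ** 2).mean(dim=(2, 3), keepdim=True)
expected = torch.clamp(x * torch.rsqrt(nu2 + frn.eps.abs()), min=0.0)

assert torch.allclose(frn(x), expected, atol=1e-6)
```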
--------------------------------------------------------------------------------
/preact_resnet/__init__.py:
--------------------------------------------------------------------------------
from .preact_resnet import preact_resnet18, preact_resnet34, preact_resnet50, preact_resnet101, preact_resnet152, preact_resnet200
from .preact_resnet_frn import preact_resnet18_frn, preact_resnet34_frn, preact_resnet50_frn, preact_resnet101_frn, preact_resnet152_frn, preact_resnet200_frn
--------------------------------------------------------------------------------
/preact_resnet/preact_resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

# Implemented with reference to
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
# and
# https://github.com/facebookarchive/fb.resnet.torch/blob/master/models/preresnet.lua

__all__ = ['preact_resnet18', 'preact_resnet34', 'preact_resnet50',
           'preact_resnet101', 'preact_resnet152', 'preact_resnet200']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class PreActBasicBlock(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(PreActBasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = norm_layer(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        # Pre-activation ordering: normalize and activate *before* each convolution.
        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class PreActBottleneck(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(PreActBottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.bn1 = norm_layer(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv1x1(inplanes, width)
        self.bn2 = norm_layer(width)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn3 = norm_layer(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu1(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu3(out)
        out = self.conv3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class PreActResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(PreActResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        # In the pre-activation design the last residual block ends with a
        # convolution, so one more norm + activation is needed before pooling.
        self.bn2 = norm_layer(self.inplanes)
        self.relu2 = nn.ReLU(inplace=True)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch, so that the
        # residual branch starts with zeros and each residual block behaves
        # like an identity. This improves the model by 0.2~0.3% according to
        # https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, PreActBottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, PreActBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _preact_resnet(arch, block, layers, **kwargs):
    # `arch` is currently unused; it is kept to mirror torchvision's factory
    # signature (no pretrained weights are provided for these models).
    model = PreActResNet(block, layers, **kwargs)
    return model


def preact_resnet18(**kwargs):
    return _preact_resnet('preact_resnet18', PreActBasicBlock, [2, 2, 2, 2], **kwargs)


def preact_resnet34(**kwargs):
    return _preact_resnet('preact_resnet34', PreActBasicBlock, [3, 4, 6, 3], **kwargs)


def preact_resnet50(**kwargs):
    return _preact_resnet('preact_resnet50', PreActBottleneck, [3, 4, 6, 3], **kwargs)


def preact_resnet101(**kwargs):
    return _preact_resnet('preact_resnet101', PreActBottleneck, [3, 4, 23, 3], **kwargs)


def preact_resnet152(**kwargs):
    return _preact_resnet('preact_resnet152', PreActBottleneck, [3, 8, 36, 3], **kwargs)


def preact_resnet200(**kwargs):
    return _preact_resnet('preact_resnet200', PreActBottleneck, [3, 24, 36, 3], **kwargs)
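The factory functions above can be smoke-tested directly. A sketch, run from the repository root; the parameter count should land near torchvision's ResNet-50, since only the norm/activation ordering and one extra final norm layer differ:

```python
import torch

from preact_resnet import preact_resnet50

model = preact_resnet50(num_classes=1000)
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
assert logits.shape == (1, 1000)

num_params = sum(p.numel() for p in model.parameters())
print('parameters: {:.1f}M'.format(num_params / 1e6))  # on the order of 25M
```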
--------------------------------------------------------------------------------
/preact_resnet/preact_resnet_frn.py:
--------------------------------------------------------------------------------
import os
import sys

import torch
import torch.nn as nn

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from frn_layer import FilterResponseNormLayer

# Implemented with reference to
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
# and
# https://github.com/facebookarchive/fb.resnet.torch/blob/master/models/preresnet.lua

__all__ = ['preact_resnet18_frn', 'preact_resnet34_frn', 'preact_resnet50_frn',
           'preact_resnet101_frn', 'preact_resnet152_frn', 'preact_resnet200_frn']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class PreActBasicBlockFRN(nn.Module):
    expansion = 1
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(PreActBasicBlockFRN, self).__init__()
        if norm_layer is None:
            norm_layer = FilterResponseNormLayer
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        # No separate ReLU modules: the TLU inside FilterResponseNormLayer
        # provides the nonlinearity.
        self.frn1 = norm_layer(inplanes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.frn2 = norm_layer(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.frn1(x)
        out = self.conv1(out)

        out = self.frn2(out)
        out = self.conv2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class PreActBottleneckFRN(nn.Module):
    expansion = 4
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(PreActBottleneckFRN, self).__init__()
        if norm_layer is None:
            norm_layer = FilterResponseNormLayer
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.frn1 = norm_layer(inplanes)
        self.conv1 = conv1x1(inplanes, width)
        self.frn2 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.frn3 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.frn1(x)
        out = self.conv1(out)

        out = self.frn2(out)
        out = self.conv2(out)

        out = self.frn3(out)
        out = self.conv3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class PreActResNetFRN(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(PreActResNetFRN, self).__init__()
        if norm_layer is None:
            norm_layer = FilterResponseNormLayer
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.frn1 = norm_layer(self.inplanes)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.frn2 = norm_layer(self.inplanes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                # FilterResponseNormLayer initializes itself in
                # reset_parameters(); this branch only matters if a
                # BatchNorm/GroupNorm norm_layer is passed in.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the scale (gamma) of the last FRN layer in each
        # residual branch, so that the branch starts with zeros and each
        # residual block behaves like an identity. This improves the model by
        # 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, PreActBottleneckFRN):
                    nn.init.constant_(m.frn3.gamma, 0)
                elif isinstance(m, PreActBasicBlockFRN):
                    nn.init.constant_(m.frn2.gamma, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.conv1(x)
        x = self.frn1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.frn2(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x):
        return self._forward_impl(x)


def _preact_resnet_frn(arch, block, layers, **kwargs):
    # `arch` is currently unused; it is kept to mirror torchvision's factory
    # signature (no pretrained weights are provided for these models).
    model = PreActResNetFRN(block, layers, **kwargs)
    return model


def preact_resnet18_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet18_frn', PreActBasicBlockFRN, [2, 2, 2, 2], **kwargs)


def preact_resnet34_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet34_frn', PreActBasicBlockFRN, [3, 4, 6, 3], **kwargs)


def preact_resnet50_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet50_frn', PreActBottleneckFRN, [3, 4, 6, 3], **kwargs)


def preact_resnet101_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet101_frn', PreActBottleneckFRN, [3, 4, 23, 3], **kwargs)


def preact_resnet152_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet152_frn', PreActBottleneckFRN, [3, 8, 36, 3], **kwargs)


def preact_resnet200_frn(**kwargs):
    return _preact_resnet_frn('preact_resnet200_frn', PreActBottleneckFRN, [3, 24, 36, 3], **kwargs)
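One structural property worth checking: the TLU built into `FilterResponseNormLayer` replaces ReLU, so the FRN models should contain no `nn.ReLU` modules at all, while the baseline models keep theirs. A sketch, run from the repository root:

```python
import torch.nn as nn

from preact_resnet import preact_resnet50, preact_resnet50_frn

def count_relus(model):
    return sum(1 for m in model.modules() if isinstance(m, nn.ReLU))

assert count_relus(preact_resnet50()) > 0
assert count_relus(preact_resnet50_frn()) == 0
```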
--------------------------------------------------------------------------------
/test/test_models.py:
--------------------------------------------------------------------------------
import os
import sys
import torch

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from preact_resnet import *

model_dict = {
    'preact_resnet18': preact_resnet18,
    'preact_resnet34': preact_resnet34,
    'preact_resnet50': preact_resnet50,
    'preact_resnet101': preact_resnet101,
    'preact_resnet152': preact_resnet152,
    'preact_resnet200': preact_resnet200,
    'preact_resnet18_frn': preact_resnet18_frn,
    'preact_resnet34_frn': preact_resnet34_frn,
    'preact_resnet50_frn': preact_resnet50_frn,
    'preact_resnet101_frn': preact_resnet101_frn,
    'preact_resnet152_frn': preact_resnet152_frn,
    'preact_resnet200_frn': preact_resnet200_frn,
}

# Dummy ImageNet-sized input batch.
x = torch.randn(2, 3, 224, 224)

for model_name, model_fn in model_dict.items():
    print('check:', model_name)
    current_model = model_fn(num_classes=10)
    with torch.no_grad():
        output = current_model(x)
    assert output.shape == (2, 10)

print('ok')
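A check the test script could additionally perform (a hypothetical extension, not in the original file): FRN statistics are computed per sample, so the prediction for an image must not depend on what else is in the batch:

```python
import torch

from preact_resnet import preact_resnet18_frn

model = preact_resnet18_frn(num_classes=10)
x = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    batched = model(x)
    single = model(x[:1])
# Identical up to floating-point noise from batched vs. unbatched kernels.
assert torch.allclose(batched[:1], single, atol=1e-5)
```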
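With the corrected `zero_init_residual` branch, `gamma` of the last FRN layer in every residual branch is zero at initialization, so any block without a downsample should act as an exact identity; a sketch:

```python
import torch

from preact_resnet import preact_resnet50_frn

model = preact_resnet50_frn(num_classes=10, zero_init_residual=True)
block = model.layer1[1]  # second block of layer1: no downsample branch
x = torch.randn(2, 256, 56, 56)
with torch.no_grad():
    assert torch.allclose(block(x), x)
```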
--------------------------------------------------------------------------------
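Finally, a hypothetical gradient smoke test: one backward pass should reach every parameter, including each FRN layer's learnable `tau`, `beta`, `gamma`, and `eps`:

```python
import torch

from preact_resnet import preact_resnet18_frn

model = preact_resnet18_frn(num_classes=10)
model(torch.randn(2, 3, 224, 224)).sum().backward()

missing = [name for name, p in model.named_parameters() if p.grad is None]
assert not missing, 'parameters without gradients: {}'.format(missing)
```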