├── LICENSE.md ├── README.md ├── asap.py ├── flc_pooling.py ├── preact_resnet_flc.py ├── resnet_asap.py ├── resnet_flc.py └── wide_resnet_flc.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Grabinski 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FrequencyLowCut Pooling 2 | Code for [FrequencyLowCut Pooling - Plug & Play against Catastrophic Overfitting](https://link.springer.com/chapter/10.1007/978-3-031-19781-9_3) and [Fix your downsampling ASAP! 
Be natively more robust via Aliasing and Spectral Artifact free Pooling](https://arxiv.org/abs/2307.09804) 3 | 4 | We provide our FrequencyLowCut (FLC) pooling module and our Aliasing and Spectral Artifact free Pooling (ASAP) module, as well as examples of how to integrate them into common CNN architectures. 5 | 6 | The code for adversarial training used in our paper can be found [here](https://github.com/locuslab/fast_adversarial). 7 | 8 | 9 | ## Citation 10 | 11 | Would you like to reference our **`FLC Pooling`** and **`ASAP`**? 12 | 13 | Then please consider citing our papers [FLC Pooling](https://link.springer.com/chapter/10.1007/978-3-031-19781-9_3) and [ASAP](https://arxiv.org/abs/2307.09804): 14 | 15 | 16 | ```bibtex 17 | @inproceedings{grabinski2022frequencylowcut, 18 | title = {FrequencyLowCut Pooling--Plug \& Play against Catastrophic Overfitting}, 19 | author = {Grabinski, Julia and Jung, Steffen and Keuper, Janis and Keuper, Margret}, 20 | booktitle = {European Conference on Computer Vision}, 21 | year = {2022}, 22 | url = {https://arxiv.org/abs/2204.00491} 23 | } 24 | 25 | @article{grabinski2023fix, 26 | title = {Fix your downsampling ASAP! Be natively more robust via Aliasing and Spectral Artifact free Pooling}, 27 | author = {Grabinski, Julia and Keuper, Janis and Keuper, Margret}, 28 | journal = {arXiv preprint arXiv:2307.09804}, 29 | year = {2023} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /asap.py: -------------------------------------------------------------------------------- 1 | '''ASAP module 2 | can be used and distributed under the MIT license 3 | Reference: 4 | [1] Grabinski, J., Jung, S., Keuper, J., & Keuper, M. (2022). 5 | "FrequencyLowCut Pooling--Plug & Play against Catastrophic Overfitting." 6 | European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2022. 7 | [2] Grabinski, J., Keuper, J. and Keuper, M. 8 | "Fix your downsampling ASAP!
Be natively more robust via Aliasing and Spectral Artifact free Pooling." 9 | arXiv preprint arXiv:2307.09804 (2023). 10 | ''' 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | import numpy as np 17 | 18 | 19 | class ASAP(nn.Module): 20 | # pooling trough selecting only the low frequent part in the fourier domain and only using this part to go back into the spatial domain 21 | # save computations as we do not need to do the downsampling trough conv with stride 2 22 | # using a hamming window to prevent sinc-interpolation artifacts 23 | def __init__(self, transpose=True): 24 | self.transpose = transpose 25 | self.window2d = None 26 | super(ASAP, self).__init__() 27 | 28 | def forward(self, x): 29 | 30 | if self.transpose: 31 | x = x.transpose(2,3) 32 | if self.window2d is None: 33 | window1d = np.abs(np.hamming(x.size(2))) 34 | window2nd = np.abs(np.hamming(x.size(3))) 35 | window2d = np.sqrt(np.outer(window1d,window2nd)) 36 | self.window2d = torch.Tensor(window2d).cuda() 37 | del window1d 38 | del window2d 39 | del window2nd 40 | 41 | low_part = torch.fft.fftshift(torch.fft.fft2(x, norm='forward')) 42 | low_part = low_part*self.window2d.unsqueeze(0).unsqueeze(0) 43 | low_part = low_part[:,:,int(x.size(2)/4):int(x.size(2)/4*3),int(x.size(3)/4):int(x.size(3)/4*3)] 44 | 45 | return torch.fft.ifft2(torch.fft.ifftshift(low_part), norm='forward').real 46 | 47 | 48 | class ASAP_padding_one(nn.Module): 49 | # pooling trough selecting only the low frequent part in the fourier domain and only using this part to go back into the spatial domain 50 | # save computations as we do not need to do the downsampling trough conv with stride 2 51 | # using a hamming window to prevent sinc-interpolation artifacts 52 | def __init__(self): 53 | self.window2d = None 54 | super(ASAP_padding_one, self).__init__() 55 | 56 | def forward(self, x): 57 | 58 | x = F.pad(x, (0, 1, 0, 1), "constant", 0) 59 | if not torch.is_tensor(self.window2d): 60 | 
window1d = np.abs(np.hamming(x.size(2))) 61 | window2nd = np.abs(np.hamming(x.size(3))) 62 | window2d = np.sqrt(np.outer(window1d,window2nd)) 63 | self.window2d = torch.Tensor(window2d).cuda() 64 | del window1d 65 | del window2d 66 | del window2nd 67 | 68 | low_part = torch.fft.fftshift(torch.fft.fft2(x, norm='forward')) 69 | low_part = low_part*self.window2d.unsqueeze(0).unsqueeze(0) 70 | low_part = low_part[:,:,int(x.size(2)/4):int(x.size(2)/4*3),int(x.size(3)/4):int(x.size(3)/4*3)] 71 | 72 | fc = torch.fft.ifft2(torch.fft.ifftshift(low_part), norm='forward').real 73 | return fc 74 | 75 | 76 | class ASAP_padding_large(nn.Module): 77 | # pooling trough selecting only the low frequent part in the fourier domain and only using this part to go back into the spatial domain 78 | # save computations as we do not need to do the downsampling trough conv with stride 2 79 | # using a hamming window to prevent sinc-interpolation artifacts 80 | def __init__(self): 81 | self.window2d = None 82 | super(ASAP_padding_large, self).__init__() 83 | 84 | def forward(self, x): 85 | 86 | x = F.pad(x, (int(x.size(3)/2-1), int(x.size(3)/2), int(x.size(2)/2-1), int(x.size(2)/2)), "constant", 0) 87 | if not torch.is_tensor(self.window2d): 88 | window1d = np.abs(np.hamming(x.size(2))) 89 | window2nd = np.abs(np.hamming(x.size(3))) 90 | window2d = np.sqrt(np.outer(window1d,window2nd)) 91 | self.window2d = torch.Tensor(window2d).cuda() 92 | del window1d 93 | del window2d 94 | del window2nd 95 | 96 | low_part = torch.fft.fftshift(torch.fft.fft2(x, norm='forward')) 97 | low_part = low_part*self.window2d.unsqueeze(0).unsqueeze(0) 98 | low_part = low_part[:,:,int(x.size(2)/4):int(x.size(2)/4*3),int(x.size(3)/4):int(x.size(3)/4*3)] 99 | 100 | fc = torch.fft.ifft2(torch.fft.ifftshift(low_part), norm='forward').real 101 | fc = fc[:,:,int(fc.size(2)/4):int(3*fc.size(2)/4),int(fc.size(3)/4): int(3*fc.size(3)/4)] 102 | return fc 103 | 
-------------------------------------------------------------------------------- /flc_pooling.py: -------------------------------------------------------------------------------- 1 | '''FLC Pooling module 2 | can be used and distributed under the MIT license 3 | Reference: 4 | [1] Grabinski, J., Jung, S., Keuper, J., & Keuper, M. (2022). 5 | "FrequencyLowCut Pooling--Plug & Play against Catastrophic Overfitting." 6 | European Conference on Computer Vision. Cham: Springer Nature Switzerland, 2022. 7 | ''' 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | class FLC_Pooling(nn.Module): 14 | # pooling trough selecting only the low frequent part in the fourier domain and only using this part to go back into the spatial domain 15 | # save computations as we do not need to do the downsampling trough conv with stride 2 16 | def __init__(self): 17 | super(FLC_Pooling, self).__init__() 18 | 19 | def forward(self, x): 20 | 21 | low_part = torch.fft.fftshift(torch.fft.fft2(x, norm='forward'))[:,:,int(x.size(2)/4):int(x.size(2)/4*3),int(x.size(3)/4):int(x.size(3)/4*3)] 22 | 23 | return torch.fft.ifft2(torch.fft.ifftshift(low_part), norm='forward').real 24 | -------------------------------------------------------------------------------- /preact_resnet_flc.py: -------------------------------------------------------------------------------- 1 | '''Generic Class for PreAct ResNet with FLC Pooling 2 | Reference: 3 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 4 | Identity Mappings in Deep Residual Networks. 
arXiv:1603.05027 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from flc_pooling import FLC_Pooling 11 | 12 | 13 | class PreActBlock(nn.Module): 14 | '''Pre-activation version of the BasicBlock.''' 15 | expansion = 1 16 | 17 | def __init__(self, in_planes, planes, stride=1, drop=0): 18 | super(PreActBlock, self).__init__() 19 | self.bn1 = nn.BatchNorm2d(in_planes) 20 | if stride == 1: 21 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 22 | else: 23 | self.conv1 = nn.Sequential( 24 | FLC_Pooling(), 25 | nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 28 | 29 | if stride != 1 or in_planes != self.expansion*planes: 30 | self.shortcut = nn.Sequential( 31 | FLC_Pooling(), 32 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=1, bias=False) 33 | ) 34 | 35 | def forward(self, x): 36 | out = F.relu(self.bn1(x)) 37 | shortcut = self.shortcut(x) if hasattr(self, 'shortcut') else x 38 | out = self.conv1(out) 39 | out = self.conv2(F.relu(self.bn2(out))) 40 | out += shortcut 41 | return out 42 | 43 | 44 | class PreActBottleneck(nn.Module): 45 | '''Pre-activation version of the original Bottleneck module.''' 46 | expansion = 4 47 | 48 | def __init__(self, in_planes, planes, stride=1, drop=0): 49 | super(PreActBottleneck, self).__init__() 50 | self.bn1 = nn.BatchNorm2d(in_planes) 51 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | if stride == 1: 54 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 55 | else: 56 | nn.Sequential( 57 | FLC_Pooling(), 58 | nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)) 59 | self.bn3 = nn.BatchNorm2d(planes) 60 | self.conv3 = nn.Conv2d(planes, 
self.expansion*planes, kernel_size=1, bias=False) 61 | 62 | if stride != 1 or in_planes != self.expansion*planes: 63 | self.shortcut = nn.Sequential( 64 | FLC_Pooling(), 65 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=1, bias=False) 66 | ) 67 | 68 | def forward(self, x): 69 | out = F.relu(self.bn1(x)) 70 | shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x 71 | out = self.conv1(out) 72 | out = self.conv2(F.relu(self.bn2(out))) 73 | out = self.conv3(F.relu(self.bn3(out))) 74 | out += shortcut 75 | return out 76 | 77 | 78 | class PreActResNet(nn.Module): 79 | def __init__(self, block, num_blocks, num_classes=10, drop=0): 80 | super(PreActResNet, self).__init__() 81 | self.in_planes = 64 82 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 83 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 84 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 85 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 86 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 87 | self.bn = nn.BatchNorm2d(512 * block.expansion) 88 | self.linear = nn.Linear(512 * block.expansion, num_classes) 89 | 90 | def _make_layer(self, block, planes, num_blocks, stride): 91 | strides = [stride] + [1]*(num_blocks-1) 92 | layers = [] 93 | for stride in strides: 94 | layers.append(block(self.in_planes, planes, stride)) 95 | self.in_planes = planes * block.expansion 96 | return nn.Sequential(*layers) 97 | 98 | def forward(self, x): 99 | out = self.conv1(x) 100 | out = self.layer1(out) 101 | out = self.layer2(out) 102 | out = self.layer3(out) 103 | out = self.layer4(out) 104 | out = F.relu(self.bn(out)) 105 | out = F.avg_pool2d(out, 4) 106 | out = out.view(out.size(0), -1) 107 | out = self.linear(out) 108 | return out 109 | 110 | 111 | def PreActResNet18(num_classes=10): 112 | return PreActResNet(PreActBlock, [2,2,2,2], num_classes=num_classes) 113 | 114 | 115 | 116 | 
class PreActResNet_normalized(PreActResNet): 117 | def __init__(self, block, num_blocks, num_classes=10, mu=[0.4914, 0.4822, 0.4465], sigma=[0.2471, 0.2435, 0.2616], device='cuda'): 118 | super(PreActResNet_normalized, self).__init__(block=block, num_blocks=num_blocks, num_classes=num_classes) 119 | self.mu = torch.Tensor(mu).float().view(3, 1, 1).to(device) 120 | self.sigma = torch.Tensor(sigma).float().view(3, 1, 1).to(device) 121 | 122 | def forward(self, x): 123 | x = (x - self.mu) / self.sigma 124 | return super(PreActResNet_normalized, self).forward(x) 125 | 126 | def PreActResNet18_normalized(num_classes=10, mu=[0.4914, 0.4822, 0.4465], sigma=[0.2471, 0.2435, 0.2616], device='cuda'): 127 | return PreActResNet_normalized(PreActBlock, [2,2,2,2], num_classes=num_classes, mu=mu, sigma=sigma, device=device) 128 | -------------------------------------------------------------------------------- /resnet_asap.py: -------------------------------------------------------------------------------- 1 | """ Generic Class for ResNet with ASAP 2 | Based on code from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py """ 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Type, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | 10 | import numpy as np 11 | # choose which ASAP version you want to use 12 | # default is ASAP with small padding: ASAP_padding_one 13 | from asap import ASAP, ASAP_padding_one, ASAP_padding_large 14 | 15 | __all__ = [ 16 | "ResNet", 17 | "resnet18", 18 | "resnet34", 19 | "resnet50", 20 | "resnet101", 21 | "resnet152", 22 | ] 23 | 24 | 25 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 26 | """3x3 convolution with padding""" 27 | if stride == 1: 28 | return nn.Conv2d( 29 | in_planes, 30 | out_planes, 31 | kernel_size=3, 32 | stride=1, 33 | padding=dilation, 34 | groups=groups, 35 | bias=False, 36 
| dilation=dilation, 37 | ) 38 | else: 39 | return nn.Sequential(ASAP_padding_one(), 40 | nn.Conv2d( 41 | in_planes, 42 | out_planes, 43 | kernel_size=3, 44 | stride=1, 45 | padding=dilation, 46 | groups=groups, 47 | bias=False, 48 | dilation=dilation, 49 | )) 50 | 51 | 52 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 53 | """1x1 convolution""" 54 | if stride == 1: 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 56 | else: 57 | return nn.Sequential(ASAP_padding_one(), nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)) 58 | 59 | 60 | class BasicBlock(nn.Module): 61 | expansion: int = 1 62 | 63 | def __init__( 64 | self, 65 | inplanes: int, 66 | planes: int, 67 | stride: int = 1, 68 | downsample: Optional[nn.Module] = None, 69 | groups: int = 1, 70 | base_width: int = 64, 71 | dilation: int = 1, 72 | norm_layer: Optional[Callable[..., nn.Module]] = None 73 | ) -> None: 74 | super().__init__() 75 | if norm_layer is None: 76 | norm_layer = nn.BatchNorm2d 77 | if groups != 1 or base_width != 64: 78 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 79 | if dilation > 1: 80 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 81 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv3x3(inplanes, planes, stride) 83 | self.bn1 = norm_layer(planes) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.conv2 = conv3x3(planes, planes) 86 | self.bn2 = norm_layer(planes) 87 | self.downsample = downsample 88 | self.stride = stride 89 | 90 | def forward(self, x: Tensor) -> Tensor: 91 | identity = x 92 | 93 | out = self.conv1(x) 94 | out = self.bn1(out) 95 | out = self.relu(out) 96 | 97 | out = self.conv2(out) 98 | out = self.bn2(out) 99 | 100 | if self.downsample is not None: 101 | identity = self.downsample(x) 102 | 103 | out += identity 104 | out = self.relu(out) 105 | 106 | return out 107 | 108 | 109 | 
class Bottleneck(nn.Module): 110 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 111 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 112 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 113 | # This variant is also known as ResNet V1.5 and improves accuracy according to 114 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 115 | 116 | expansion: int = 4 117 | 118 | def __init__( 119 | self, 120 | inplanes: int, 121 | planes: int, 122 | stride: int = 1, 123 | downsample: Optional[nn.Module] = None, 124 | groups: int = 1, 125 | base_width: int = 64, 126 | dilation: int = 1, 127 | norm_layer: Optional[Callable[..., nn.Module]] = None 128 | ) -> None: 129 | super().__init__() 130 | if norm_layer is None: 131 | norm_layer = nn.BatchNorm2d 132 | width = int(planes * (base_width / 64.0)) * groups 133 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 134 | self.conv1 = conv1x1(inplanes, width) 135 | self.bn1 = norm_layer(width) 136 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 137 | self.bn2 = norm_layer(width) 138 | self.conv3 = conv1x1(width, planes * self.expansion) 139 | self.bn3 = norm_layer(planes * self.expansion) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.downsample = downsample 142 | self.stride = stride 143 | 144 | def forward(self, x: Tensor) -> Tensor: 145 | identity = x 146 | 147 | out = self.conv1(x) 148 | out = self.bn1(out) 149 | out = self.relu(out) 150 | 151 | out = self.conv2(out) 152 | out = self.bn2(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv3(out) 156 | out = self.bn3(out) 157 | 158 | if self.downsample is not None: 159 | identity = self.downsample(x) 160 | 161 | out += identity 162 | out = self.relu(out) 163 | 164 | return out 165 | 166 | 167 | class ResNet(nn.Module): 168 | def __init__( 169 | self, 
170 | block: Type[Union[BasicBlock, Bottleneck]], 171 | layers: List[int], 172 | num_classes: int = 1000, 173 | zero_init_residual: bool = False, 174 | groups: int = 1, 175 | width_per_group: int = 64, 176 | replace_stride_with_dilation: Optional[List[bool]] = None, 177 | norm_layer: Optional[Callable[..., nn.Module]] = None, 178 | ) -> None: 179 | super().__init__() 180 | # _log_api_usage_once(self) 181 | if norm_layer is None: 182 | norm_layer = nn.BatchNorm2d 183 | self._norm_layer = norm_layer 184 | 185 | self.inplanes = 64 186 | self.dilation = 1 187 | if replace_stride_with_dilation is None: 188 | # each element in the tuple indicates if we should replace 189 | # the 2x2 stride with a dilated convolution instead 190 | replace_stride_with_dilation = [False, False, False] 191 | if len(replace_stride_with_dilation) != 3: 192 | raise ValueError( 193 | "replace_stride_with_dilation should be None " 194 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 195 | ) 196 | self.groups = groups 197 | self.base_width = width_per_group 198 | self.conv1 = nn.Sequential(ASAP_padding_one(), nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)) 199 | self.bn1 = norm_layer(self.inplanes) 200 | self.relu = nn.ReLU(inplace=True) 201 | self.maxpool = ASAP_padding_one() # nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 202 | self.layer1 = self._make_layer(block, 64, layers[0]) 203 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 204 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 205 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 206 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 207 | self.fc = nn.Linear(512 * block.expansion, num_classes) 208 | 209 | for m in self.modules(): 210 | if isinstance(m, nn.Conv2d): 211 | nn.init.kaiming_normal_(m.weight, mode="fan_out", 
nonlinearity="relu") 212 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 213 | nn.init.constant_(m.weight, 1) 214 | nn.init.constant_(m.bias, 0) 215 | 216 | # Zero-initialize the last BN in each residual branch, 217 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 218 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 219 | if zero_init_residual: 220 | for m in self.modules(): 221 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 222 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 223 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 224 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 225 | 226 | def _make_layer( 227 | self, 228 | block: Type[Union[BasicBlock, Bottleneck]], 229 | planes: int, 230 | blocks: int, 231 | stride: int = 1, 232 | dilate: bool = False, 233 | ) -> nn.Sequential: 234 | norm_layer = self._norm_layer 235 | downsample = None 236 | previous_dilation = self.dilation 237 | if dilate: 238 | self.dilation *= stride 239 | stride = 1 240 | if stride != 1 or self.inplanes != planes * block.expansion: 241 | downsample = nn.Sequential( 242 | conv1x1(self.inplanes, planes * block.expansion, stride), 243 | norm_layer(planes * block.expansion), 244 | ) 245 | 246 | layers = [] 247 | layers.append( 248 | block( 249 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 250 | ) 251 | ) 252 | self.inplanes = planes * block.expansion 253 | for _ in range(1, blocks): 254 | layers.append( 255 | block( 256 | self.inplanes, 257 | planes, 258 | groups=self.groups, 259 | base_width=self.base_width, 260 | dilation=self.dilation, 261 | norm_layer=norm_layer, 262 | ) 263 | ) 264 | 265 | return nn.Sequential(*layers) 266 | 267 | def _forward_impl(self, x: Tensor) -> Tensor: 268 | # See note [TorchScript super()] 269 | x = self.conv1(x) 270 | x = self.bn1(x) 271 | x = self.relu(x) 
272 | x = self.maxpool(x) 273 | 274 | x = self.layer1(x) 275 | x = self.layer2(x) 276 | x = self.layer3(x) 277 | x = self.layer4(x) 278 | 279 | x = self.avgpool(x) 280 | x = torch.flatten(x, 1) 281 | x = self.fc(x) 282 | 283 | return x 284 | 285 | def forward(self, x: Tensor) -> Tensor: 286 | return self._forward_impl(x) 287 | 288 | 289 | def _resnet( 290 | block: Type[Union[BasicBlock, Bottleneck]], 291 | layers: List[int], 292 | # weights: Optional[WeightsEnum], 293 | progress: bool, 294 | **kwargs: Any, 295 | ) -> ResNet: 296 | # if weights is not None: 297 | # _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 298 | 299 | model = ResNet(block, layers, **kwargs) 300 | 301 | # if weights is not None: 302 | # model.load_state_dict(weights.get_state_dict(progress=progress)) 303 | 304 | return model 305 | 306 | 307 | 308 | 309 | def resnet18(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 310 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 311 | 312 | """ 313 | 314 | return _resnet(BasicBlock, [2, 2, 2, 2], progress, **kwargs) 315 | 316 | 317 | 318 | def resnet34(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 319 | """ResNet-34 from `Deep Residual Learning for Image Recognition `__. 320 | 321 | """ 322 | 323 | return _resnet(BasicBlock, [3, 4, 6, 3], progress, **kwargs) 324 | 325 | 326 | def resnet50(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 327 | """ResNet-50 from `Deep Residual Learning for Image Recognition `__. 328 | 329 | .. note:: 330 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 331 | convolution while the original paper places it to the first 1x1 convolution. 332 | This variant improves the accuracy and is known as `ResNet V1.5 333 | `_. 
334 | 335 | """ 336 | 337 | return _resnet(Bottleneck, [3, 4, 6, 3], progress, **kwargs) 338 | 339 | 340 | def resnet101(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 341 | """ResNet-101 from `Deep Residual Learning for Image Recognition `__. 342 | 343 | .. note:: 344 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 345 | convolution while the original paper places it to the first 1x1 convolution. 346 | This variant improves the accuracy and is known as `ResNet V1.5 347 | `_. 348 | 349 | """ 350 | 351 | return _resnet(Bottleneck, [3, 4, 23, 3], progress, **kwargs) 352 | 353 | 354 | def resnet152(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 355 | """ResNet-152 from `Deep Residual Learning for Image Recognition `__. 356 | 357 | .. note:: 358 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 359 | convolution while the original paper places it to the first 1x1 convolution. 360 | This variant improves the accuracy and is known as `ResNet V1.5 361 | `_. 
362 | """ 363 | 364 | return _resnet(Bottleneck, [3, 8, 36, 3], progress, **kwargs) 365 | -------------------------------------------------------------------------------- /resnet_flc.py: -------------------------------------------------------------------------------- 1 | """ Generic Class for ResNet with FLC Pooling 2 | Based on code from https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py """ 3 | from functools import partial 4 | from typing import Any, Callable, List, Optional, Type, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | 10 | import numpy as np 11 | from flc_pooling import FLC_Pooling 12 | 13 | __all__ = [ 14 | "ResNet", 15 | "resnet18", 16 | "resnet34", 17 | "resnet50", 18 | "resnet101", 19 | "resnet152", 20 | ] 21 | 22 | 23 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 24 | """3x3 convolution with padding""" 25 | if stride == 1: 26 | return nn.Conv2d( 27 | in_planes, 28 | out_planes, 29 | kernel_size=3, 30 | stride=1, 31 | padding=dilation, 32 | groups=groups, 33 | bias=False, 34 | dilation=dilation, 35 | ) 36 | else: 37 | return nn.Sequential(FLC_Pooling(), 38 | nn.Conv2d( 39 | in_planes, 40 | out_planes, 41 | kernel_size=3, 42 | stride=1, 43 | padding=dilation, 44 | groups=groups, 45 | bias=False, 46 | dilation=dilation, 47 | )) 48 | 49 | 50 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 51 | """1x1 convolution""" 52 | if stride == 1: 53 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) 54 | else: 55 | return nn.Sequential(FLC_Pooling(), nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)) 56 | 57 | 58 | class BasicBlock(nn.Module): 59 | expansion: int = 1 60 | 61 | def __init__( 62 | self, 63 | inplanes: int, 64 | planes: int, 65 | stride: int = 1, 66 | downsample: Optional[nn.Module] = None, 67 | groups: int = 1, 68 | base_width: int = 64, 69 | 
dilation: int = 1, 70 | norm_layer: Optional[Callable[..., nn.Module]] = None 71 | ) -> None: 72 | super().__init__() 73 | if norm_layer is None: 74 | norm_layer = nn.BatchNorm2d 75 | if groups != 1 or base_width != 64: 76 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 77 | if dilation > 1: 78 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 79 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 80 | self.conv1 = conv3x3(inplanes, planes, stride) 81 | self.bn1 = norm_layer(planes) 82 | self.relu = nn.ReLU(inplace=True) 83 | self.conv2 = conv3x3(planes, planes) 84 | self.bn2 = norm_layer(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x: Tensor) -> Tensor: 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | identity = self.downsample(x) 100 | 101 | out += identity 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 109 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 110 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385. 111 | # This variant is also known as ResNet V1.5 and improves accuracy according to 112 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
113 | 114 | expansion: int = 4 115 | 116 | def __init__( 117 | self, 118 | inplanes: int, 119 | planes: int, 120 | stride: int = 1, 121 | downsample: Optional[nn.Module] = None, 122 | groups: int = 1, 123 | base_width: int = 64, 124 | dilation: int = 1, 125 | norm_layer: Optional[Callable[..., nn.Module]] = None, 126 | ) -> None: 127 | super().__init__() 128 | if norm_layer is None: 129 | norm_layer = nn.BatchNorm2d 130 | width = int(planes * (base_width / 64.0)) * groups 131 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 132 | self.conv1 = conv1x1(inplanes, width) 133 | self.bn1 = norm_layer(width) 134 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 135 | self.bn2 = norm_layer(width) 136 | self.conv3 = conv1x1(width, planes * self.expansion) 137 | self.bn3 = norm_layer(planes * self.expansion) 138 | self.relu = nn.ReLU(inplace=True) 139 | self.downsample = downsample 140 | self.stride = stride 141 | 142 | def forward(self, x: Tensor) -> Tensor: 143 | identity = x 144 | 145 | out = self.conv1(x) 146 | out = self.bn1(out) 147 | out = self.relu(out) 148 | 149 | out = self.conv2(out) 150 | out = self.bn2(out) 151 | out = self.relu(out) 152 | 153 | out = self.conv3(out) 154 | out = self.bn3(out) 155 | 156 | if self.downsample is not None: 157 | identity = self.downsample(x) 158 | 159 | out += identity 160 | out = self.relu(out) 161 | 162 | return out 163 | 164 | 165 | class ResNet(nn.Module): 166 | def __init__( 167 | self, 168 | block: Type[Union[BasicBlock, Bottleneck]], 169 | layers: List[int], 170 | num_classes: int = 1000, 171 | zero_init_residual: bool = False, 172 | groups: int = 1, 173 | width_per_group: int = 64, 174 | replace_stride_with_dilation: Optional[List[bool]] = None, 175 | norm_layer: Optional[Callable[..., nn.Module]] = None, 176 | ) -> None: 177 | super().__init__() 178 | # _log_api_usage_once(self) 179 | if norm_layer is None: 180 | norm_layer = nn.BatchNorm2d 181 | self._norm_layer = norm_layer 
182 | 183 | self.inplanes = 64 184 | self.dilation = 1 185 | if replace_stride_with_dilation is None: 186 | # each element in the tuple indicates if we should replace 187 | # the 2x2 stride with a dilated convolution instead 188 | replace_stride_with_dilation = [False, False, False] 189 | if len(replace_stride_with_dilation) != 3: 190 | raise ValueError( 191 | "replace_stride_with_dilation should be None " 192 | f"or a 3-element tuple, got {replace_stride_with_dilation}" 193 | ) 194 | self.groups = groups 195 | self.base_width = width_per_group 196 | self.conv1 = nn.Sequential(FLC_Pooling(), nn.Conv2d(3, self.inplanes, kernel_size=7, stride=1, padding=3, bias=False)) 197 | self.bn1 = norm_layer(self.inplanes) 198 | self.relu = nn.ReLU(inplace=True) 199 | self.maxpool = FLC_Pooling() # nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 200 | self.layer1 = self._make_layer(block, 64, layers[0]) 201 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) 202 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) 203 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) 204 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 205 | self.fc = nn.Linear(512 * block.expansion, num_classes) 206 | 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 210 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 211 | nn.init.constant_(m.weight, 1) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | # Zero-initialize the last BN in each residual branch, 215 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
216 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 217 | if zero_init_residual: 218 | for m in self.modules(): 219 | if isinstance(m, Bottleneck) and m.bn3.weight is not None: 220 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 221 | elif isinstance(m, BasicBlock) and m.bn2.weight is not None: 222 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 223 | 224 | def _make_layer( 225 | self, 226 | block: Type[Union[BasicBlock, Bottleneck]], 227 | planes: int, 228 | blocks: int, 229 | stride: int = 1, 230 | dilate: bool = False, 231 | ) -> nn.Sequential: 232 | norm_layer = self._norm_layer 233 | downsample = None 234 | previous_dilation = self.dilation 235 | if dilate: 236 | self.dilation *= stride 237 | stride = 1 238 | if stride != 1 or self.inplanes != planes * block.expansion: 239 | downsample = nn.Sequential( 240 | conv1x1(self.inplanes, planes * block.expansion, stride), 241 | norm_layer(planes * block.expansion), 242 | ) 243 | 244 | layers = [] 245 | layers.append( 246 | block( 247 | self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer 248 | ) 249 | ) 250 | self.inplanes = planes * block.expansion 251 | for _ in range(1, blocks): 252 | layers.append( 253 | block( 254 | self.inplanes, 255 | planes, 256 | groups=self.groups, 257 | base_width=self.base_width, 258 | dilation=self.dilation, 259 | norm_layer=norm_layer, 260 | ) 261 | ) 262 | 263 | return nn.Sequential(*layers) 264 | 265 | def _forward_impl(self, x: Tensor) -> Tensor: 266 | # See note [TorchScript super()] 267 | x = self.conv1(x) 268 | x = self.bn1(x) 269 | x = self.relu(x) 270 | x = self.maxpool(x) 271 | 272 | x = self.layer1(x) 273 | x = self.layer2(x) 274 | x = self.layer3(x) 275 | x = self.layer4(x) 276 | 277 | x = self.avgpool(x) 278 | x = torch.flatten(x, 1) 279 | x = self.fc(x) 280 | 281 | return x 282 | 283 | def forward(self, x: Tensor) -> Tensor: 284 | return self._forward_impl(x) 
285 | 286 | 287 | def _resnet( 288 | block: Type[Union[BasicBlock, Bottleneck]], 289 | layers: List[int], 290 | # weights: Optional[WeightsEnum], 291 | progress: bool, 292 | **kwargs: Any, 293 | ) -> ResNet: 294 | # if weights is not None: 295 | # _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) 296 | 297 | model = ResNet(block, layers, **kwargs) 298 | 299 | # if weights is not None: 300 | # model.load_state_dict(weights.get_state_dict(progress=progress)) 301 | 302 | return model 303 | 304 | 305 | 306 | 307 | def resnet18(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 308 | """ResNet-18 from `Deep Residual Learning for Image Recognition `__. 309 | 310 | """ 311 | 312 | return _resnet(BasicBlock, [2, 2, 2, 2], progress, **kwargs) 313 | 314 | 315 | 316 | def resnet34(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 317 | """ResNet-34 from `Deep Residual Learning for Image Recognition `__. 318 | 319 | """ 320 | 321 | return _resnet(BasicBlock, [3, 4, 6, 3], progress, **kwargs) 322 | 323 | 324 | def resnet50(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 325 | """ResNet-50 from `Deep Residual Learning for Image Recognition `__. 326 | 327 | .. note:: 328 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 329 | convolution while the original paper places it to the first 1x1 convolution. 330 | This variant improves the accuracy and is known as `ResNet V1.5 331 | `_. 332 | 333 | """ 334 | 335 | return _resnet(Bottleneck, [3, 4, 6, 3], progress, **kwargs) 336 | 337 | 338 | def resnet101(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 339 | """ResNet-101 from `Deep Residual Learning for Image Recognition `__. 340 | 341 | .. note:: 342 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 343 | convolution while the original paper places it to the first 1x1 convolution. 
344 | This variant improves the accuracy and is known as `ResNet V1.5 345 | `_. 346 | 347 | """ 348 | 349 | return _resnet(Bottleneck, [3, 4, 23, 3], progress, **kwargs) 350 | 351 | 352 | def resnet152(*, weights = None, progress: bool = True, **kwargs: Any) -> ResNet: 353 | """ResNet-152 from `Deep Residual Learning for Image Recognition `__. 354 | 355 | .. note:: 356 | The bottleneck of TorchVision places the stride for downsampling to the second 3x3 357 | convolution while the original paper places it to the first 1x1 convolution. 358 | This variant improves the accuracy and is known as `ResNet V1.5 359 | `_. 360 | """ 361 | 362 | return _resnet(Bottleneck, [3, 8, 36, 3], progress, **kwargs) 363 | -------------------------------------------------------------------------------- /wide_resnet_flc.py: -------------------------------------------------------------------------------- 1 | """ Generic Class for Wide ResNet with FLC Pooling 2 | Based on code from https://github.com/yaodongyu/TRADES """ 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from flc_pooling import FLC_Pooling 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | def __init__(self, in_planes, out_planes, stride, dropRate=0.0, cutoff=[8,24,4,12,2,6]): 13 | super(BasicBlock, self).__init__() 14 | self.bn1 = nn.BatchNorm2d(in_planes) 15 | self.relu1 = nn.ReLU(inplace=True) 16 | if stride == 1: 17 | self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 18 | padding=1, bias=False) 19 | else: 20 | self.conv1 = nn.Sequential( 21 | FLC_Pooling(), 22 | nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, 23 | padding=1, bias=False) 24 | ) 25 | self.bn2 = nn.BatchNorm2d(out_planes) 26 | self.relu2 = nn.ReLU(inplace=True) 27 | self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, 28 | padding=1, bias=False) 29 | self.droprate = dropRate 30 | self.equalInOut = (in_planes == out_planes) 31 | 32 | if not self.equalInOut: 33 
| if stride == 1: 34 | self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 35 | else: 36 | self.convShortcut = nn.Sequential( 37 | FLC_Pooling(), 38 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) 39 | ) 40 | else: 41 | self.convShortcut = None 42 | 43 | def forward(self, x): 44 | if not self.equalInOut: 45 | x = self.relu1(self.bn1(x)) 46 | else: 47 | out = self.relu1(self.bn1(x)) 48 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 49 | if self.droprate > 0: 50 | out = F.dropout(out, p=self.droprate, training=self.training) 51 | out = self.conv2(out) 52 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 53 | 54 | 55 | class NetworkBlock(nn.Module): 56 | def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0): 57 | super(NetworkBlock, self).__init__() 58 | self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate) 59 | 60 | def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate): 61 | layers = [] 62 | for i in range(int(nb_layers)): 63 | layers.append(block(i == 0 and in_planes or out_planes, out_planes, i == 0 and stride or 1, dropRate)) 64 | return nn.Sequential(*layers) 65 | 66 | def forward(self, x): 67 | return self.layer(x) 68 | 69 | 70 | class WideResNet(nn.Module): 71 | def __init__(self, depth=28, num_classes=10, widen_factor=10, sub_block1=False, dropRate=0.0, bias_last=True): 72 | super(WideResNet, self).__init__() 73 | nChannels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 74 | assert ((depth - 4) % 6 == 0) 75 | n = (depth - 4) / 6 76 | block = BasicBlock 77 | # 1st conv before any network block 78 | self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, 79 | padding=1, bias=False) 80 | # 1st block 81 | self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 82 | if sub_block1: 83 | # 1st sub-block 
84 | self.sub_block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate) 85 | # 2nd block 86 | self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate) 87 | # 3rd block 88 | self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate) 89 | # global average pooling and classifier 90 | self.bn1 = nn.BatchNorm2d(nChannels[3]) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.fc = nn.Linear(nChannels[3], num_classes, bias=bias_last) 93 | self.nChannels = nChannels[3] 94 | 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. / n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.Linear) and not m.bias is None: 103 | m.bias.data.zero_() 104 | 105 | def forward(self, x): 106 | out = self.conv1(x) 107 | out = self.block1(out) 108 | out = self.block2(out) 109 | out = self.block3(out) 110 | out = self.relu(self.bn1(out)) 111 | out = F.avg_pool2d(out, 8) 112 | out = out.view(-1, self.nChannels) 113 | return self.fc(out) 114 | 115 | 116 | class WRN_normalized(WideResNet): 117 | def __init__(self, device='cuda'): 118 | super(WRN_normalized, self).__init__() 119 | self.mu = torch.Tensor([0.4914, 0.4822, 0.4465]).float().view(3, 1, 1).to(device) 120 | self.sigma = torch.Tensor([0.2471, 0.2435, 0.2616]).float().view(3, 1, 1).to(device) 121 | 122 | def forward(self, x): 123 | x = (x - self.mu) / self.sigma 124 | return super(WRN_normalized, self).forward(x) 125 | 126 | def WideResNet2810_normalized(device='cuda'): 127 | return WRN_normalized(device=device) 128 | 129 | def WideResNet2810(device='cuda'): 130 | return WideResNet() 131 | --------------------------------------------------------------------------------